package di import ( "crypto/tls" "errors" "fmt" "net/http" "strings" "testing" ) func ExampleBuild() { type Username string type Config struct { User Username Age int Greeting func(Username, int) string } cfg := Config{ User: "Alice", Age: 42, Greeting: func(u Username, age int) string { return fmt.Sprintf("Hello, %s. You've been around the sun %d times!", string(u), age) }, } type Result struct { Greeting string } var res Result err := Build(cfg, &res) if err != nil { fmt.Println(err) return } fmt.Println(res.Greeting) // Output: Hello, Alice. You've been around the sun 42 times! } func ExampleNew() { type Username string type Config struct { User Username Age int Greeting func(Username, int) string } cfg := Config{ User: "Alice", Age: 42, Greeting: func(u Username, age int) string { return fmt.Sprintf("Hello, %s. You've been around the sun %d times!", string(u), age) }, } greeting, err := New[string](cfg) if err != nil { fmt.Println(err) return } fmt.Println(greeting) // Output: Hello, Alice. You've been around the sun 42 times! } func ExampleOptional() { type Config struct { // A *tls.Config type or Provide[*tls.Config] also works, but using the // Optional wrapper lets us convey the optionality explicitly TLSConfig Optional[*tls.Config] Server Provide[*http.Server] } cfg := Config{ Server: MustProvide[*http.Server](func( tlsConf Optional[*tls.Config], ) *http.Server { s := &http.Server{ Addr: ":8080", } if tlsConf.IsSome { s.TLSConfig = tlsConf.Val } return s }), } server, err := New[*http.Server](cfg) if err != nil { fmt.Println(err) return } if server.TLSConfig == nil { fmt.Println("No TLS configuration was provided") } // Output: No TLS configuration was provided } func ExampleSideEffect() { type Config struct { Server *http.Server SideEffects []Provide[SideEffect] } type Result struct { StartedServer *http.Server _ []SideEffect } res, err := New[Result](&Config{ Server: &http.Server{ Addr: ":8080", }, SideEffects: []Provide[SideEffect]{ MustProvide[SideEffect](func() SideEffect { fmt.Println("Starting server...") go http.ListenAndServe(":8080", nil) return SideEffect{} }), }, }) if err != nil { fmt.Println(err) return } defer res.StartedServer.Close() // Output: Starting server... } func TestBuildSuccess(t *testing.T) { type A struct { val string } type C struct { val int } type B struct { a *A cs []C } type NestedConfig struct { OtherSetting bool NestedDecision func(c NestedConfig) uint } type ANum int type ConfigWithInner struct { NestedConfig SomeSetting bool Inner func(c ConfigWithInner) int } tests := []struct { name string config any result any verify func(t *testing.T, result any) }{ { name: "complex dependencies with providers", config: struct { MakeA Provide[*A] MakeB Provide[*B] MakeCs []Provide[C] }{ MakeA: MustProvide[*A](func() (*A, error) { return &A{val: "hello"}, nil }), MakeB: MustProvide[*B](func(a *A, cs []C) *B { return &B{a: a, cs: cs} }), MakeCs: []Provide[C]{ MustProvide[C](C{val: 1}), MustProvide[C](func() (C, error) { return C{val: 2}, nil }), }, }, result: &struct { A *A B *B }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A *A B *B }) if res.A == nil { t.Fatalf("expected res.A to be populated") } if res.B == nil { t.Fatalf("expected res.B to be populated") } if res.B.a != res.A { t.Fatalf("expected B.a to reference A instance") } if len(res.B.cs) != 2 { t.Fatalf("wrong count. Saw %d", len(res.B.cs)) } if res.B.cs[0].val != 1 { t.Fatalf("wrong value") } if res.B.cs[1].val != 2 { t.Fatalf("wrong value") } if res.A.val != "hello" { t.Fatalf("unexpected A value: %s", res.A.val) } }, }, { name: "simple function constructors", config: struct { MakeA func() (*A, error) MakeB func(*A) (*B, error) }{ MakeA: func() (*A, error) { return &A{val: "hello"}, nil }, MakeB: func(a *A) (*B, error) { panic("Unexpected call to MakeB") }, }, result: &struct { A *A }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A *A }) if res.A == nil { t.Fatalf("expected res.A to be populated") } if res.A.val != "hello" { t.Fatalf("unexpected A value: %s", res.A.val) } }, }, { name: "pre-supplied values", config: struct { A *A MB func(*A) (*B, error) }{ A: &A{val: "pre-supplied"}, MB: func(a *A) (*B, error) { return &B{a: a}, nil }, }, result: &struct { A *A B *B }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A *A B *B }) if res.A == nil || res.A.val != "pre-supplied" { t.Fatalf("expected pre-supplied A, got %+v", res.A) } if res.B == nil || res.B.a != res.A { t.Fatalf("expected B referencing A, got %+v", res.B) } }, }, { name: "pre-supplied nil values", config: struct { A *A }{ A: nil, }, result: &struct { A *A }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A *A }) if res.A != nil { t.Fatalf("expected nil A, got %+v", res.A) } }, }, { name: "Explicit Optional Value", config: struct { A Optional[*A] }{}, result: &struct { A Optional[*A] }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A Optional[*A] }) if res.A.IsSome { t.Fatalf("expected none") } }, }, { name: "Explicit Provided Optional Value", config: struct { A Optional[*A] }{A: Some(&A{})}, result: &struct { A Optional[*A] }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A Optional[*A] }) _ = res.A.Unwrap() }, }, { name: "type aliases", config: struct { A ANum B int }{A: 3, B: 4}, result: &struct { A ANum }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A ANum }) if res.A != 3 { t.Fatalf("expected A=3, got %v", res.A) } }, }, { name: "reference config in constructors", config: ConfigWithInner{ SomeSetting: true, Inner: func(c ConfigWithInner) int { if c.SomeSetting { return 1 } return 0 }, NestedConfig: NestedConfig{ OtherSetting: true, NestedDecision: func(c NestedConfig) uint { if c.OtherSetting { return 1 } return 0 }, }, }, result: &struct { A int B uint }{}, verify: func(t *testing.T, result any) { res := result.(*struct { A int B uint }) if res.A != 1 { t.Fatalf("expected A=1, got %v", res.A) } if res.B != 1 { t.Fatalf("expected B=1, got %v", res.B) } }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := Build(tt.config, tt.result) if err != nil { t.Fatalf("Build failed: %v", err) } tt.verify(t, tt.result) }) } } func TestBuildErrors(t *testing.T) { type A struct{} type B struct { a *A } type X struct{} type Y struct{} sentinel := errors.New("boom") tests := []struct { name string config any result any expectedErrors []string verify func(t *testing.T, result any) }{ { name: "constructor error propagation", config: struct { MakeA func() (*A, error) }{ MakeA: func() (*A, error) { return nil, sentinel }, }, result: &struct { A *A }{}, expectedErrors: []string{"MakeA", "boom"}, verify: func(t *testing.T, result any) { res := result.(*struct { A *A }) if res.A != nil { t.Fatalf("result A should not be populated on constructor failure") } }, }, { name: "missing dependency", config: struct { MakeB func(*A) (*B, error) }{ MakeB: func(a *A) (*B, error) { return &B{a: a}, nil }, }, result: &struct { B *B }{}, expectedErrors: []string{"*di.A", "di.A"}, verify: func(t *testing.T, result any) { res := result.(*struct { B *B }) if res.B != nil { t.Fatalf("result B should not be populated") } }, }, { name: "cycle detection", config: struct { MakeX func(*Y) *X MakeY func(*X) *Y }{ MakeX: func(y *Y) *X { return &X{} }, MakeY: func(x *X) *Y { return &Y{} }, }, result: &struct { X *X Y *Y }{}, expectedErrors: []string{"MakeX", "MakeY"}, verify: func(t *testing.T, result any) { res := result.(*struct { X *X Y *Y }) if res.X != nil || res.Y != nil { t.Fatalf("cycle should prevent any construction; got X=%v Y=%v", res.X, res.Y) } }, }, { name: "Zero value provider", config: struct { A Provide[*A] B Provide[*B] }{ B: MustProvide[*B](func(*A) *B { return &B{} }), }, result: &struct { B *B }{}, expectedErrors: []string{"Missing Provide[*di.A]"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := Build(tt.config, tt.result) if err == nil { t.Fatalf("expected error") } errorStr := err.Error() var foundError bool for _, expectedErr := range tt.expectedErrors { if strings.Contains(errorStr, expectedErr) { foundError = true break } } if !foundError { t.Fatalf("expected error to contain one of %v, got: %v", tt.expectedErrors, err) } if tt.verify != nil { tt.verify(t, tt.result) } }) } } func TestNewFunction(t *testing.T) { type ANum int tests := []struct { name string config any expected any }{ { name: "struct result", config: struct { A int }{A: 42}, expected: struct { A int }{A: 42}, }, { name: "primitive type result", config: struct{ A int }{A: 42}, expected: 42, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { switch expected := tt.expected.(type) { case struct{ A int }: res, err := New[struct{ A int }](tt.config) if err != nil { t.Fatalf("unexpected error: %v", err) } if res.A != expected.A { t.Fatalf("expected A=%d, got %v", expected.A, res.A) } case int: res, err := New[int](tt.config) if err != nil { t.Fatalf("unexpected error: %v", err) } if res != expected { t.Fatalf("expected %d, got %v", expected, res) } } }) } } func TestBuildPrimitiveTypes(t *testing.T) { tests := []struct { name string config any expected int }{ { name: "build primitive directly", config: struct{ A int }{A: 42}, expected: 42, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var res int err := Build(tt.config, &res) if err != nil { t.Fatalf("unexpected error: %v", err) } if res != tt.expected { t.Fatalf("expected %d, got %v", tt.expected, res) } }) } }