package di
import (
"errors"
"fmt"
"net/http"
"strings"
"testing"
)
// ExampleBuild shows Build filling a caller-owned result struct. As the
// expected output demonstrates, the string field of Result is produced by
// calling cfg.Greeting, whose Username and int arguments come from cfg.User
// and cfg.Age.
func ExampleBuild() {
	type Username string
	type Config struct {
		User     Username
		Age      int
		Greeting func(Username, int) string
	}
	type Result struct {
		Greeting string
	}

	greet := func(name Username, years int) string {
		return fmt.Sprintf("Hello, %s. You've been around the sun %d times!", string(name), years)
	}
	cfg := Config{User: "Alice", Age: 42, Greeting: greet}

	var res Result
	if err := Build(cfg, &res); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(res.Greeting)
	// Output: Hello, Alice. You've been around the sun 42 times!
}
// ExampleNew shows New constructing a single value of the requested type.
// Asking for a string makes the container call the config's Greeting
// function, feeding it the Username and int also held in the config.
func ExampleNew() {
	type Username string
	type Config struct {
		User     Username
		Age      int
		Greeting func(Username, int) string
	}

	greet := func(name Username, years int) string {
		return fmt.Sprintf("Hello, %s. You've been around the sun %d times!", string(name), years)
	}

	msg, err := New[string](Config{User: "Alice", Age: 42, Greeting: greet})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(msg)
	// Output: Hello, Alice. You've been around the sun 42 times!
}
// ExampleSideEffect shows how to request values that exist only for their side
// effects. The blank []SideEffect field in Result asks the container to run the
// SideEffect providers; here the single provider starts the configured server.
func ExampleSideEffect() {
	type Config struct {
		Server      *http.Server
		SideEffects []Provide[SideEffect]
	}
	type Result struct {
		// StartedServer is the same *http.Server supplied in Config.Server.
		StartedServer *http.Server
		_             []SideEffect
	}
	res, err := New[Result](&Config{
		Server: &http.Server{
			// Port 0 lets the OS pick a free loopback port so the example never
			// collides with another process; a real program would use a fixed
			// address such as ":8080".
			Addr: "127.0.0.1:0",
		},
		SideEffects: []Provide[SideEffect]{
			// Depend on the injected *http.Server so the server that gets
			// started is the one the caller later closes. Starting a separate
			// http.ListenAndServe would leak a listener that Close never stops.
			MustProvide[SideEffect](func(srv *http.Server) SideEffect {
				fmt.Println("Starting server...")
				go func() {
					// The error is deliberately ignored: once Close is called
					// (even before serving begins) ListenAndServe returns
					// http.ErrServerClosed.
					_ = srv.ListenAndServe()
				}()
				return SideEffect{}
			}),
		},
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	defer res.StartedServer.Close()
	// Output: Starting server...
}
// TestBuildSuccess exercises Build against configurations that should resolve
// cleanly: Provide-wrapped constructors, plain constructor functions,
// pre-supplied values, distinct named types, and constructors that take the
// config struct (or an embedded part of it) as their argument.
func TestBuildSuccess(t *testing.T) {
	type A struct {
		val string
	}
	type C struct {
		val int
	}
	type B struct {
		a  *A
		cs []C
	}
	type NestedConfig struct {
		OtherSetting   bool
		NestedDecision func(c NestedConfig) uint
	}
	// ANum is a defined type distinct from int (not a `type ANum = int` alias).
	type ANum int
	type ConfigWithInner struct {
		NestedConfig
		SomeSetting bool
		Inner       func(c ConfigWithInner) int
	}
	tests := []struct {
		name   string
		config any
		result any
		verify func(t *testing.T, result any)
	}{
		{
			name: "complex dependencies with providers",
			config: struct {
				MakeA  Provide[*A]
				MakeB  Provide[*B]
				MakeCs []Provide[C]
			}{
				MakeA: MustProvide[*A](func() (*A, error) {
					return &A{val: "hello"}, nil
				}),
				MakeB: MustProvide[*B](func(a *A, cs []C) *B {
					return &B{a: a, cs: cs}
				}),
				// MakeB consumes []C built from this slice of providers: one is a
				// plain value, the other a constructor. The verify step expects
				// declaration order to be preserved.
				MakeCs: []Provide[C]{
					MustProvide[C](C{val: 1}),
					MustProvide[C](func() (C, error) {
						return C{val: 2}, nil
					}),
				},
			},
			result: &struct {
				A *A
				B *B
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					A *A
					B *B
				})
				if res.A == nil {
					t.Fatalf("expected res.A to be populated")
				}
				if res.B == nil {
					t.Fatalf("expected res.B to be populated")
				}
				if res.B.a != res.A {
					t.Fatalf("expected B.a to reference A instance; got B.a=%p, A=%p", res.B.a, res.A)
				}
				if len(res.B.cs) != 2 {
					t.Fatalf("wrong count. Saw %d", len(res.B.cs))
				}
				for i, want := range []int{1, 2} {
					if got := res.B.cs[i].val; got != want {
						t.Fatalf("wrong value: B.cs[%d].val = %d, want %d", i, got, want)
					}
				}
				if res.A.val != "hello" {
					t.Fatalf("unexpected A value: %s", res.A.val)
				}
			},
		},
		{
			name: "simple function constructors",
			config: struct {
				MakeA func() (*A, error)
				MakeB func(*A) (*B, error)
			}{
				MakeA: func() (*A, error) {
					return &A{val: "hello"}, nil
				},
				// The result never asks for *B, so Build must not invoke MakeB;
				// the panic turns an unwanted call into a test failure.
				MakeB: func(a *A) (*B, error) {
					panic("Unexpected call to MakeB")
				},
			},
			result: &struct {
				A *A
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					A *A
				})
				if res.A == nil {
					t.Fatalf("expected res.A to be populated")
				}
				if res.A.val != "hello" {
					t.Fatalf("unexpected A value: %s", res.A.val)
				}
			},
		},
		{
			name: "pre-supplied values",
			config: struct {
				A  *A
				MB func(*A) (*B, error)
			}{
				A: &A{val: "pre-supplied"},
				MB: func(a *A) (*B, error) {
					return &B{a: a}, nil
				},
			},
			result: &struct {
				A *A
				B *B
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					A *A
					B *B
				})
				if res.A == nil || res.A.val != "pre-supplied" {
					t.Fatalf("expected pre-supplied A, got %+v", res.A)
				}
				if res.B == nil || res.B.a != res.A {
					t.Fatalf("expected B referencing A, got %+v", res.B)
				}
			},
		},
		{
			// Despite the subtest name, ANum is a defined type: the config
			// supplies both an ANum and an int, and the result must receive the
			// ANum value rather than the int.
			name: "type aliases",
			config: struct {
				A ANum
				B int
			}{A: 3, B: 4},
			result: &struct {
				A ANum
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					A ANum
				})
				if res.A != 3 {
					t.Fatalf("expected A=3, got %v", res.A)
				}
			},
		},
		{
			// Inner receives the whole config and NestedDecision receives the
			// embedded NestedConfig, so both must be injectable.
			name: "reference config in constructors",
			config: ConfigWithInner{
				SomeSetting: true,
				Inner: func(c ConfigWithInner) int {
					if c.SomeSetting {
						return 1
					}
					return 0
				},
				NestedConfig: NestedConfig{
					OtherSetting: true,
					NestedDecision: func(c NestedConfig) uint {
						if c.OtherSetting {
							return 1
						}
						return 0
					},
				},
			},
			result: &struct {
				A int
				B uint
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					A int
					B uint
				})
				if res.A != 1 {
					t.Fatalf("expected A=1, got %v", res.A)
				}
				if res.B != 1 {
					t.Fatalf("expected B=1, got %v", res.B)
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := Build(tt.config, tt.result); err != nil {
				t.Fatalf("Build failed: %v", err)
			}
			tt.verify(t, tt.result)
		})
	}
}
// TestBuildErrors checks that Build reports failures for a failing constructor,
// an unsatisfiable dependency, and a dependency cycle. For each case the error
// text must mention at least one of the listed substrings, and the result must
// be left unpopulated.
func TestBuildErrors(t *testing.T) {
	type A struct{}
	type B struct {
		a *A
	}
	type X struct{}
	type Y struct{}
	sentinel := errors.New("boom")

	// mentionsAny reports whether msg contains at least one of the needles.
	// NOTE(review): any-of matching means the first case passes even if only
	// "MakeA" appears; consider errors.Is(err, sentinel) if Build wraps with %w.
	mentionsAny := func(msg string, needles []string) bool {
		for _, needle := range needles {
			if strings.Contains(msg, needle) {
				return true
			}
		}
		return false
	}

	cases := []struct {
		name           string
		config         any
		result         any
		expectedErrors []string
		verify         func(t *testing.T, result any)
	}{
		{
			name: "constructor error propagation",
			config: struct {
				MakeA func() (*A, error)
			}{
				MakeA: func() (*A, error) { return nil, sentinel },
			},
			result:         &struct{ A *A }{},
			expectedErrors: []string{"MakeA", "boom"},
			verify: func(t *testing.T, result any) {
				if got := result.(*struct{ A *A }); got.A != nil {
					t.Fatalf("result A should not be populated on constructor failure")
				}
			},
		},
		{
			name: "missing dependency",
			config: struct {
				MakeB func(*A) (*B, error)
			}{
				MakeB: func(a *A) (*B, error) { return &B{a: a}, nil },
			},
			result:         &struct{ B *B }{},
			expectedErrors: []string{"*di.A", "di.A"},
			verify: func(t *testing.T, result any) {
				if got := result.(*struct{ B *B }); got.B != nil {
					t.Fatalf("result B should not be populated")
				}
			},
		},
		{
			name: "cycle detection",
			config: struct {
				MakeX func(*Y) *X
				MakeY func(*X) *Y
			}{
				MakeX: func(y *Y) *X { return &X{} },
				MakeY: func(x *X) *Y { return &Y{} },
			},
			result: &struct {
				X *X
				Y *Y
			}{},
			expectedErrors: []string{"MakeX", "MakeY"},
			verify: func(t *testing.T, result any) {
				got := result.(*struct {
					X *X
					Y *Y
				})
				if got.X != nil || got.Y != nil {
					t.Fatalf("cycle should prevent any construction; got X=%v Y=%v", got.X, got.Y)
				}
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := Build(tc.config, tc.result)
			if err == nil {
				t.Fatalf("expected error")
			}
			if !mentionsAny(err.Error(), tc.expectedErrors) {
				t.Fatalf("expected error to contain one of %v, got: %v", tc.expectedErrors, err)
			}
			tc.verify(t, tc.result)
		})
	}
}
// TestNewFunction checks that New[T] can produce either a struct or a bare
// primitive from a config. Each case's expected value also selects the type
// argument passed to New.
func TestNewFunction(t *testing.T) {
	tests := []struct {
		name     string
		config   any
		expected any
	}{
		{
			name: "struct result",
			config: struct {
				A int
			}{A: 42},
			expected: struct {
				A int
			}{A: 42},
		},
		{
			name:     "primitive type result",
			config:   struct{ A int }{A: 42},
			expected: 42,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// New's type argument must be fixed at compile time, so switch on
			// the expected value's dynamic type to choose it.
			switch expected := tt.expected.(type) {
			case struct{ A int }:
				res, err := New[struct{ A int }](tt.config)
				if err != nil {
					t.Fatalf("unexpected error: %v", err)
				}
				if res.A != expected.A {
					t.Fatalf("expected A=%d, got %v", expected.A, res.A)
				}
			case int:
				res, err := New[int](tt.config)
				if err != nil {
					t.Fatalf("unexpected error: %v", err)
				}
				if res != expected {
					t.Fatalf("expected %d, got %v", expected, res)
				}
			default:
				// Without this, a case with an unhandled expected type would
				// run no assertions and silently pass.
				t.Fatalf("unsupported expected type %T", tt.expected)
			}
		})
	}
}
// TestBuildPrimitiveTypes verifies that Build can fill a pointer to a bare
// primitive (here an int) rather than a struct.
func TestBuildPrimitiveTypes(t *testing.T) {
	cases := []struct {
		name     string
		config   any
		expected int
	}{
		{name: "build primitive directly", config: struct{ A int }{A: 42}, expected: 42},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var got int
			if err := Build(tc.config, &got); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if got != tc.expected {
				t.Fatalf("expected %d, got %v", tc.expected, got)
			}
		})
	}
}