~marcopolo/di

baa3947d3bc1bddb0a034f5054121078962a226a — Marco Munizaga 9 months ago
Initial Version
6 files changed, 947 insertions(+), 0 deletions(-)

A .build.yml
A LICENSE
A README.md
A di.go
A di_test.go
A go.mod
A  => .build.yml +9 -0
@@ 1,9 @@
image: alpine/edge
packages:
  - go
sources:
  - https://git.sr.ht/~marcopolo/di
tasks:
  - test: |
      cd di
      go test ./...

A  => LICENSE +11 -0
@@ 1,11 @@
Copyright 2025 Marco Munizaga

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

A  => README.md +9 -0
@@ 1,9 @@
# di

A small straightforward DI library for Go.

## Usage

## Examples

## Why another DI library?

A  => di.go +553 -0
@@ 1,553 @@
// Package di is a minimal reflection-based dependency injection helper.
//
// Usage pattern:
//
//	type Config struct {
//	    Logger di.Provide[*slog.Logger]
//	    Server func(logger *slog.Logger) (*http.Server, error)
//	}
//
//	type Result struct {
//	    Server *http.Server
//	}
//
//	cfg := Config{ /* constructors here ... */ }
//	var res Result
//	if err := di.Build(cfg, &res); err != nil { ... }\
//
// Or with New(cfg):
//
// server, err := di.New[*http.Server](cfg)
//
// See the doc comments for Build and New for more details
package di

import (
	"errors"
	"fmt"
	"reflect"
	"strings"
)

type Provide[Out any] struct {
	fOrV any
}

// SideEffect is a sentinel value representing a constructor used only for side
// effects such as introducing two components together without a circular
// dependency.
//
// It's idiomatic to have your config declare a SideEffects field of type
// []di.MustProvide[di.SideEffect], and have a "_" field in your result struct
// of type []di.SideEffect.
type SideEffect struct{}

func NewSideEffect(f any) (Provide[SideEffect], error) {
	return NewProvide[SideEffect](f)
}

func MustSideEffect(f any) Provide[SideEffect] {
	return Must(NewSideEffect(f))
}

type provideI interface {
	diOutType() reflect.Type
	diPayload() any
}

func (p Provide[Out]) diOutType() reflect.Type { return typeOf[Out]() }
func (p Provide[Out]) diPayload() any          { return p.fOrV }

func Must[T any](t T, err error) T {
	if err != nil {
		panic(err)
	}
	return t
}

func MustProvide[Out any](ctorOrVal any) Provide[Out] {
	return Must(NewProvide[Out](ctorOrVal))
}

func NewProvide[Out any](ctorOrVal any) (Provide[Out], error) {
	outType := typeOf[Out]()

	if ctorOrVal == nil {
		if canBeNil(outType) {
			return Provide[Out]{fOrV: nil}, nil
		}
		return Provide[Out]{}, fmt.Errorf("Provide[%v]: nil not valid for non-nilable type", outType)
	}

	t := reflect.TypeOf(ctorOrVal)

	// Case 1: function returning Out or (Out, error). Arbitrary args allowed.
	if t.Kind() == reflect.Func {
		nout := t.NumOut()
		switch nout {
		case 1:
			if !t.Out(0).AssignableTo(outType) {
				return Provide[Out]{}, fmt.Errorf("Provide[%v]: function return %v is not assignable to %v",
					outType, t.Out(0), outType)
			}
			return Provide[Out]{fOrV: ctorOrVal}, nil

		case 2:
			if !t.Out(0).AssignableTo(outType) {
				return Provide[Out]{}, fmt.Errorf("Provide[%v]: first return %v is not assignable to %v",
					outType, t.Out(0), outType)
			}
			if !isErrorType(t.Out(1)) {
				return Provide[Out]{}, fmt.Errorf("Provide[%v]: second return must be error, got %v",
					outType, t.Out(1))
			}
			return Provide[Out]{fOrV: ctorOrVal}, nil

		default:
			return Provide[Out]{}, fmt.Errorf("Provide[%v]: function must return Out or (Out, error); got %d returns",
				outType, nout)
		}
	}

	// Case 2: value assignable to Out (covers interface satisfaction)
	if !t.AssignableTo(outType) {
		return Provide[Out]{}, fmt.Errorf("Provide[%v]: value of type %v is not assignable to %v", outType, t, outType)
	}
	return Provide[Out]{fOrV: ctorOrVal}, nil
}

// Helpers
func typeOf[T any]() reflect.Type {
	return reflect.TypeOf((*T)(nil)).Elem()
}

func canBeNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice:
		return true
	default:
		return false
	}
}

type ctor struct {
	name string
	fn   reflect.Value
	out  reflect.Type
}

type collector struct {
	// Singular instances & providers
	values    map[reflect.Type]reflect.Value // exact type -> instance
	providers map[reflect.Type][]ctor        // out type -> ctors

	// List providers for []T
	listValues    map[reflect.Type][]reflect.Value // elem dynamic type -> instances
	listProviders map[reflect.Type][]ctor          // elem out type -> ctors
	listPresence  map[reflect.Type]bool            // elem T present explicitly (even if empty)

	// Cycle detection
	resolving map[reflect.Type]bool
}

func New[R any, C any](config C) (R, error) {
	var r R
	err := Build(config, &r)
	return r, err
}

// Build resolves only what's needed to populate exported fields in result.
// Supports arbitrarily nested provider namespaced structs inside config.
//
// # TODO rewrite this doc
//
// Supported providers in config (at any nesting depth):
//   - Provide[T]                    // single constructor or value for T
//   - []Provide[T]                  // list of constructors/values contributing to []T
//   - func(...Deps) T / (T,error)   // singular constructor for T
//   - value of type T               // preprovided singular value
//   - []func(...Deps) T/(T,error)   // contributes to []T
//   - []T                           // contributes to []T
//
// Non-func, non-zero exported fields remain prebound instances.
// result must be a pointer to a struct or value. If it is a value the config
// must define how to construct it.
func Build[C any, R any](config C, result R) error {
	cfgV := reflect.ValueOf(config)
	for cfgV.IsValid() && cfgV.Kind() == reflect.Pointer {
		if cfgV.IsNil() {
			return errors.New("config pointer is nil")
		}
		cfgV = cfgV.Elem()
	}
	if !cfgV.IsValid() || cfgV.Kind() != reflect.Struct {
		return errors.New("config must be a struct or pointer to struct")
	}

	resV := reflect.ValueOf(result)
	if !resV.IsValid() || resV.Kind() != reflect.Pointer {
		return errors.New("result must be a pointer to struct or a pointer to a value")
	}

	c := &collector{
		values:        make(map[reflect.Type]reflect.Value),
		providers:     make(map[reflect.Type][]ctor),
		listValues:    make(map[reflect.Type][]reflect.Value),
		listProviders: make(map[reflect.Type][]ctor),
		listPresence:  make(map[reflect.Type]bool),
		resolving:     make(map[reflect.Type]bool),
	}

	// Provide access to the Config value itself
	c.values[cfgV.Type()] = cfgV
	if err := c.collect(cfgV, ""); err != nil {
		return err
	}

	// Can we just resolve the direct result type?
	if v, err := c.resolve(resV.Elem().Type()); err == nil {
		resV.Elem().Set(v)
		return nil
	}

	if resV.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("couldn't build result direct, and can not fill result as it is not a struct")
	}

	// Populate result fields of the struct
	var missing []string
	resStruct := resV.Elem()
	resT := resStruct.Type()
	for i := 0; i < resT.NumField(); i++ {
		sf := resT.Field(i)
		if sf.Name != "_" && sf.PkgPath != "" {
			continue
		}
		fv := resStruct.Field(i)
		if !fv.IsZero() {
			continue
		}
		v, err := c.resolve(sf.Type)
		if err != nil {
			missing = append(missing, fmt.Sprintf("%s (%s): %v", sf.Name, sf.Type, err))
			continue
		}

		// Evaluate, but do not set underscore field names
		if sf.Name != "_" {
			fv.Set(v)
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("failed to build result fields:\n  - %s", strings.Join(missing, "\n  - "))
	}
	return nil
}

func (c *collector) collect(v reflect.Value, path string) error {
	if !v.IsValid() {
		return nil
	}
	// Deref pointers
	for v.Kind() == reflect.Pointer {
		if v.IsNil() {
			return nil
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return nil
	}

	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if sf.PkgPath != "" { // unexported
			continue
		}
		fv := v.Field(i)
		name := sf.Name
		if path != "" {
			name = path + "." + sf.Name
		}

		// Provide[T] (singular) must be recognized BEFORE treating structs as namespaces.
		if pi, ok := asProvide(fv); ok {
			outT := pi.diOutType()
			payload := pi.diPayload()
			if payload == nil {
				c.values[outT] = reflect.Zero(outT)
				continue
			}
			pt := reflect.TypeOf(payload)
			if pt.Kind() == reflect.Func {
				if err := validateCtorSignature(pt, name); err != nil {
					return err
				}
				c.providers[pt.Out(0)] = append(c.providers[pt.Out(0)], ctor{name: name, fn: reflect.ValueOf(payload), out: pt.Out(0)})
			} else {
				c.values[pt] = reflect.ValueOf(payload)
			}
			continue
		}

		// Namespace recursion for embedded structs
		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
			// Provide access to the nested config value itself
			c.values[sf.Type] = fv

			if err := c.collect(fv, name); err != nil {
				return err
			}
			continue
		}

		// []Provide[T] (list)
		if fv.Kind() == reflect.Slice && fv.Type().Elem().Kind() == reflect.Struct {
			if provideElem, ok := reflect.New(fv.Type().Elem()).Elem().Interface().(provideI); ok {
				if fv.Len() == 0 && !fv.IsNil() {
					// Set presence of empty list
					outT := provideElem.diOutType()
					c.listPresence[outT] = true
				}

				for j := 0; j < fv.Len(); j++ {
					pi := fv.Index(j).Interface().(provideI)
					outT := pi.diOutType()
					c.listPresence[outT] = true
					payload := pi.diPayload()
					if payload == nil {
						c.listValues[outT] = append(c.listValues[outT], reflect.Zero(outT))
						continue
					}
					pt := reflect.TypeOf(payload)
					if pt.Kind() == reflect.Func {
						if err := validateCtorSignature(pt, fmt.Sprintf("%s[%d]", name, j)); err != nil {
							return err
						}
						c.listProviders[pt.Out(0)] = append(c.listProviders[pt.Out(0)], ctor{
							name: fmt.Sprintf("%s[%d]", name, j),
							fn:   reflect.ValueOf(payload),
							out:  pt.Out(0),
						})
					} else {
						c.listValues[pt] = append(c.listValues[pt], reflect.ValueOf(payload))
					}
				}
				continue
			}
		}

		// Fallback to earlier forms (singular func/value; list of funcs/values)
		switch sf.Type.Kind() {
		case reflect.Func:
			ft := fv.Type()
			if err := validateCtorSignature(ft, name); err != nil {
				return err
			}
			c.providers[ft.Out(0)] = append(c.providers[ft.Out(0)], ctor{name: name, fn: fv, out: ft.Out(0)})

		case reflect.Slice:
			elemT := sf.Type.Elem()
			if elemT.Kind() == reflect.Func {
				for j := 0; j < fv.Len(); j++ {
					fn := fv.Index(j)
					ft := fn.Type()
					if err := validateCtorSignature(ft, fmt.Sprintf("%s[%d]", name, j)); err != nil {
						return err
					}
					c.listProviders[ft.Out(0)] = append(c.listProviders[ft.Out(0)], ctor{
						name: fmt.Sprintf("%s[%d]", name, j),
						fn:   fn, out: ft.Out(0),
					})
				}
			} else {
				if fv.Len() == 0 {
					c.listPresence[elemT] = true // explicit empty list present
				}
				for j := 0; j < fv.Len(); j++ {
					vj := fv.Index(j)
					c.listValues[vj.Type()] = append(c.listValues[vj.Type()], vj)
				}
			}

		default:
			// preprovided singular instance (non-zero only)
			if !fv.IsZero() {
				c.values[fv.Type()] = fv
			}
		}
	}
	return nil
}

func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
	// Cached exact?
	if v, ok := c.values[t]; ok {
		return v, nil
	}

	// Slice resolution []T
	if t.Kind() == reflect.Slice {
		elem := t.Elem()
		var elems []reflect.Value
		found := false

		// Explicit list values (concrete types assignable to elem)
		for haveT, vals := range c.listValues {
			if isAssignableOrImpl(haveT, elem) {
				found = true
				elems = append(elems, vals...)
			}
		}
		// From list constructors whose out is assignable to elem
		for outT, ctors := range c.listProviders {
			if !isAssignableOrImpl(outT, elem) {
				continue
			}
			found = true
			for _, ctor := range ctors {
				if c.resolving[t] {
					return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t)
				}
				c.resolving[t] = true
				ft := ctor.fn.Type()
				args := make([]reflect.Value, ft.NumIn())
				for i := 0; i < ft.NumIn(); i++ {
					paramT := ft.In(i)
					arg, err := c.resolve(paramT)
					if err != nil {
						delete(c.resolving, t)
						return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err)
					}
					if !isAssignableOrImpl(arg.Type(), paramT) {
						delete(c.resolving, t)
						return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT)
					}
					args[i] = arg
				}
				outs := ctor.fn.Call(args)
				delete(c.resolving, t)
				if len(outs) == 2 && !outs[1].IsNil() {
					return reflect.Value{}, fmt.Errorf("%s error: %w", ctor.name, outs[1].Interface().(error))
				}
				elems = append(elems, outs[0])
			}
		}

		// If an explicit provider for elem T exists (e.g., []Provide[T] or []T present but empty),
		// we should still succeed with an empty slice.
		if !found {
			if c.listPresence[elem] {
				c.values[t] = reflect.MakeSlice(t, 0, 0)
				return c.values[t], nil
			}
			return reflect.Value{}, fmt.Errorf("no provider for %s", t)
		}

		slice := reflect.MakeSlice(t, 0, len(elems))
		for _, e := range elems {
			slice = reflect.Append(slice, e)
		}
		c.values[t] = slice
		return slice, nil
	}

	// Try existing instances for interface targets (singular)
	if t.Kind() == reflect.Interface {
		for haveT, v := range c.values {
			if haveT.Implements(t) {
				return v, nil
			}
		}
	}

	// Cycle detection (singular)
	if c.resolving[t] {
		return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t)
	}
	c.resolving[t] = true
	defer delete(c.resolving, t)

	// Pick singular provider(s)
	var candidates []ctor
	if ps, ok := c.providers[t]; ok {
		candidates = append(candidates, ps...)
	} else if t.Kind() == reflect.Interface {
		for outT, ps := range c.providers {
			if outT.Implements(t) {
				candidates = append(candidates, ps...)
			}
		}
	}

	if len(candidates) == 0 {
		return reflect.Value{}, fmt.Errorf("no provider for %s", t)
	}
	if len(candidates) > 1 {
		var names []string
		for _, candidate := range candidates {
			names = append(names, candidate.name+" -> "+candidate.out.String())
		}
		return reflect.Value{}, fmt.Errorf("ambiguous providers for %s: %s", t, strings.Join(names, ", "))
	}

	impl := candidates[0]
	ft := impl.fn.Type()
	args := make([]reflect.Value, ft.NumIn())
	for i := 0; i < ft.NumIn(); i++ {
		paramT := ft.In(i)
		arg, err := c.resolve(paramT)
		if err != nil {
			return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err)
		}
		if !isAssignableOrImpl(arg.Type(), paramT) {
			return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT)
		}
		args[i] = arg
	}
	outs := impl.fn.Call(args)
	if len(outs) == 2 {
		if !outs[1].IsNil() {
			return reflect.Value{}, fmt.Errorf("%s error: %w", impl.name, outs[1].Interface().(error))
		}
		c.values[impl.out] = outs[0]
		return outs[0], nil
	}
	c.values[impl.out] = outs[0]
	return outs[0], nil
}

// --- helpers ---

func validateCtorSignature(ft reflect.Type, name string) error {
	if ft.IsVariadic() {
		return fmt.Errorf("constructor %q: variadics not supported", name)
	}
	nout := ft.NumOut()
	if nout == 1 {
		return nil
	}
	if nout == 2 && isErrorType(ft.Out(1)) {
		return nil
	}
	return fmt.Errorf("constructor %q: must return (T) or (T, error); got %d returns", name, nout)
}

func isAssignableOrImpl(have, want reflect.Type) bool {
	return have.AssignableTo(want) || (want.Kind() == reflect.Interface && have.Implements(want))
}

func isErrorType(t reflect.Type) bool {
	return t == reflect.TypeOf((*error)(nil)).Elem()
}

// asProvide tries to view v as a Provide[*]. Returns (iface, true) if so.
func asProvide(v reflect.Value) (provideI, bool) {
	if !v.IsValid() {
		return nil, false
	}
	x := v.Interface()
	pi, ok := x.(provideI)
	return pi, ok
}

A  => di_test.go +362 -0
@@ 1,362 @@
package di

import (
	"errors"
	"strings"
	"testing"
)

func TestBuildSuccess(t *testing.T) {
	type A struct {
		val string
	}
	type C struct {
		val int
	}
	type B struct {
		a  *A
		cs []C
	}

	type Config struct {
		MakeA  Provide[*A]
		MakeB  Provide[*B]
		MakeCs []Provide[C]
	}

	cfg := Config{
		MakeA: MustProvide[*A](func() (*A, error) {
			return &A{val: "hello"}, nil
		}),
		MakeB: MustProvide[*B](func(a *A, cs []C) *B {
			return &B{a: a, cs: cs}
		}),
		MakeCs: []Provide[C]{
			MustProvide[C](C{val: 1}),
			MustProvide[C](func() (C, error) {
				return C{val: 2}, nil
			})},
	}

	type Result struct {
		A *A
		B *B
	}
	var res Result
	err := Build(cfg, &res)
	if err != nil {
		t.Fatalf("Build failed: %v", err)
	}
	if res.A == nil {
		t.Fatalf("expected res.A to be populated")
	}
	if res.B == nil {
		t.Fatalf("expected res.B to be populated")
	}
	if res.B.a != res.A {
		t.Fatalf("expected B.a to reference A instance")
	}
	if len(res.B.cs) != 2 {
		t.Fatalf("wrong count. Saw %d", len(res.B.cs))
	}
	if res.B.cs[0].val != 1 {
		t.Fatalf("wrong value")
	}
	if res.B.cs[1].val != 2 {
		t.Fatalf("wrong value")
	}
	if res.A.val != "hello" {
		t.Fatalf("unexpected A value: %s", res.A.val)
	}
}

func TestBuildSuccess2(t *testing.T) {
	type A struct {
		val string
	}
	type B struct {
		a *A
	}

	type Config struct {
		MakeA func() (*A, error)
		MakeB func(*A) (*B, error)
	}

	type Result struct {
		A *A
	}

	cfg := Config{
		MakeA: func() (*A, error) {
			return &A{val: "hello"}, nil
		},
		MakeB: func(a *A) (*B, error) {
			panic("Unexpected call to MakeB")
			// (removed unreachable code after panic)
		},
	}

	var res Result
	err := Build(cfg, &res)
	if err != nil {
		t.Fatalf("Build failed: %v", err)
	}
	if res.A == nil {
		t.Fatalf("expected res.A to be populated")
	}
	if res.A.val != "hello" {
		t.Fatalf("unexpected A value: %s", res.A.val)
	}
}

// Test that constructor error is propagated.
func TestBuildConstructorError(t *testing.T) {
	type A struct{}
	sentinel := errors.New("boom")

	type Config struct {
		MakeA func() (*A, error)
	}
	type Result struct {
		A *A
	}

	cfg := Config{
		MakeA: func() (*A, error) {
			return nil, sentinel
		},
	}
	var res Result
	err := Build(cfg, &res)
	if err == nil {
		t.Fatalf("expected error")
	}
	if !strings.Contains(err.Error(), "MakeA") {
		t.Fatalf("expected error to mention constructor name, got: %v", err)
	}
	if !strings.Contains(err.Error(), "boom") {
		t.Fatalf("expected original error message, got: %v", err)
	}
	if res.A != nil {
		t.Fatalf("result A should not be populated on constructor failure")
	}
}

// Test missing dependency (constructor requires *A but *A not provided).
func TestBuildMissingDependency(t *testing.T) {
	type A struct{}
	type B struct {
		a *A
	}

	type Config struct {
		MakeB func(*A) (*B, error)
	}
	type Result struct {
		B *B
	}

	cfg := Config{
		MakeB: func(a *A) (*B, error) {
			return &B{a: a}, nil
		},
	}
	var res Result
	err := Build(cfg, &res)
	if err == nil {
		t.Fatalf("expected missing dependency error")
	}
	// Parameter type string should appear (may be *di.A).
	if !strings.Contains(err.Error(), "*di.A") && !strings.Contains(err.Error(), "di.A") {
		t.Fatalf("expected error to mention missing type *di.A, got: %v", err)
	}
	if res.B != nil {
		t.Fatalf("result B should not be populated")
	}
}

// Test cycle detection between X and Y.
func TestBuildCycleDetection(t *testing.T) {
	type X struct{}
	type Y struct{}

	type Config struct {
		MakeX func(*Y) *X
		MakeY func(*X) *Y
	}
	type Result struct {
		X *X
		Y *Y
	}

	cfg := Config{
		MakeX: func(y *Y) *X {
			return &X{}
		},
		MakeY: func(x *X) *Y {
			return &Y{}
		},
	}

	var res Result
	err := Build(cfg, &res)
	if err == nil {
		t.Fatalf("expected cycle detection error")
	}
	// Both constructors should still be listed as remaining.
	if !strings.Contains(err.Error(), "MakeX") || !strings.Contains(err.Error(), "MakeY") {
		t.Fatalf("expected error to list remaining constructors MakeX and MakeY, got: %v", err)
	}
	if res.X != nil || res.Y != nil {
		t.Fatalf("cycle should prevent any construction; got X=%v Y=%v", res.X, res.Y)
	}
}

// Ensure that providing pre-supplied value satisfies dependency without constructor for it.
func TestBuildWithPreSuppliedValue(t *testing.T) {
	type A struct {
		v int
	}
	type B struct {
		a *A
	}

	type Config struct {
		// Only constructor for B; A provided directly in config.
		A  *A
		MB func(*A) (*B, error)
	}
	type Result struct {
		A *A
		B *B
	}

	cfg := Config{
		A: &A{v: 42},
		MB: func(a *A) (*B, error) {
			return &B{a: a}, nil
		},
	}

	var res Result
	err := Build(cfg, &res)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res.A == nil || res.A.v != 42 {
		t.Fatalf("expected pre-supplied A (42), got %+v", res.A)
	}
	if res.B == nil || res.B.a != res.A {
		t.Fatalf("expected B referencing A, got %+v", res.B)
	}
}

func TestTypeAlias(t *testing.T) {
	type ANum int
	type Config struct {
		A ANum
		B int
	}

	type Result struct {
		A ANum
	}
	var res Result
	err := Build(Config{3, 4}, &res)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res.A != 3 {
		t.Fatalf("expected A=3, got %v", res.A)
	}
}

func TestReferenceConfig(t *testing.T) {
	type NestedConfig struct {
		OtherSetting   bool
		NestedDecision func(c NestedConfig) uint
	}

	type Config struct {
		NestedConfig
		SomeSetting bool
		Inner       func(c Config) int
	}

	type Result struct {
		A int
		B uint
	}
	var res Result
	err := Build(Config{
		SomeSetting: true,
		Inner: func(c Config) int {
			if c.SomeSetting {
				return 1
			}
			return 0
		},
		NestedConfig: NestedConfig{
			OtherSetting: true,
			NestedDecision: func(c NestedConfig) uint {
				if c.OtherSetting {
					return 1
				}
				return 0
			},
		},
	}, &res)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res.A != 1 {
		t.Fatalf("expected A=1, got %v", res.A)
	}
	if res.B != 1 {
		t.Fatalf("expected B=1, got %v", res.A)
	}
}

func TestNew(t *testing.T) {
	type Config struct {
		A int
	}

	type Result struct {
		A int
	}

	cfg := Config{A: 42}
	res, err := New[Result](cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res.A != 42 {
		t.Fatalf("expected A=42, got %v", res.A)
	}
}

func TestSpecificTypes(t *testing.T) {
	type Config struct {
		A int
	}

	cfg := Config{A: 42}
	res, err := New[int](cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res != 42 {
		t.Fatalf("expected A=42, got %v", res)
	}

	var res2 int
	err = Build(Config{A: 42}, &res2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if res2 != 42 {
		t.Fatalf("expected A=42, got %v", res2)
	}
}

A  => go.mod +3 -0
@@ 1,3 @@
module git.sr.ht/~marcopolo/di

go 1.24.0