Generic type in a switch statement

Time:05-03

I just started learning generics. I'm building a command processor, and I honestly don't know how to word this, so I'll just show an example problem:

package bus

import (
    "context"
    "errors"
    "fmt"
)

var ErrInvalidCommand = errors.New("invalid command")

type TransactionalFn[T any] func(ctx context.Context, db T) error

func NewTransactionalCommand[T any](fn TransactionalFn[T]) *TransactionalCommand[T] {
    return &TransactionalCommand[T]{
        fn: fn,
    }
}

type TransactionalCommand[T any] struct {
    fn TransactionalFn[T]
}

func (cmd *TransactionalCommand[T]) StartTransaction() error {
    return nil
}

func (cmd *TransactionalCommand[T]) Commit() error {
    return nil
}

func (cmd *TransactionalCommand[T]) Rollback() error {
    return nil
}

type CMD interface{}

type CommandManager struct{}

func (m *CommandManager) Handle(ctx context.Context, cmd CMD) error {
    switch t := cmd.(type) {
    case *TransactionalCommand[any]:
        return m.handleTransactionalCommand(ctx, t)
    default:
        fmt.Printf("%T\n", cmd)
        return ErrInvalidCommand
    }
}

func (m *CommandManager) handleTransactionalCommand(ctx context.Context, cmd *TransactionalCommand[any]) error {
    if err := cmd.StartTransaction(); err != nil {
        return err
    }

    if err := cmd.fn(ctx, nil); err != nil {
        // Roll back and report the command's error instead of falling through to Commit.
        if rbErr := cmd.Rollback(); rbErr != nil {
            return rbErr
        }
        return err
    }

    if err := cmd.Commit(); err != nil {
        return err
    }

    return nil
}

// tests
type db struct{}

func (*db) Do() {
    fmt.Println("doing stuff")
}

func TestCMD(t *testing.T) {
    ctx := context.Background()
    fn := func(ctx context.Context, db *db) error {
        fmt.Println("test cmd")
        db.Do()
        return nil
    }
    tFn := bus.NewTransactionalCommand(fn)

    mng := &bus.CommandManager{}
    err := mng.Handle(ctx, tFn)
    if err != nil {
        t.Fatal(err)
    }
}

mng.Handle returns ErrInvalidCommand, so the test fails: cmd is a *TransactionalCommand[*db], not a *TransactionalCommand[any].

Let me give another, more abstract example:

type A[T any] struct{}

func (*A[T]) DoA() { fmt.Println("do A") }

type B[T any] struct{}

func (*B[T]) DoB() { fmt.Println("do B") }

func Handle(s interface{}) {
    switch x := s.(type) {
    case *A[any]:
        x.DoA()
    case *B[any]:
        x.DoB()
    default:
        fmt.Printf("%T\n", s)
    }
}

func TestFuncSwitch(t *testing.T) {
    i := &A[int]{}

    Handle(i) // expected to print "do A"
}

Why doesn't the case *A[any] in this switch statement match *A[int]? And how can I make CommandManager.Handle(...) accept generic commands?

Answer:

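*A[any] does not match *A[int] because any is an ordinary static type (an alias for interface{}), not a wildcard. Instantiating a generic struct with different type arguments yields different, unrelated types, and a type switch case with a non-interface type matches only when the dynamic type is identical. *A[int] is not *A[any], just as []int is not []any.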

To match a generic struct in a type switch without hard-coding the type argument, instantiate the case type with a type parameter of the enclosing function:

func Handle[T any](s interface{}) {
    switch x := s.(type) {
    case *A[T]:
        x.DoA()
    case *B[T]:
        x.DoB()
    default:
        panic("no match")
    }
}

However, since no other function argument mentions T, you have to call Handle with an explicit instantiation: T cannot be inferred from s, whose static type is interface{}.

func main() {
    i := &A[int]{}
    Handle[int](i) // expected to print "do A"
}

Playground: https://go.dev/play/p/2e5E9LSWPmk
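The sketch below is a variant of the answer's Handle, assuming DoA and DoB return strings instead of printing (an illustrative change so the result can be checked). It shows that the case *A[T] matches only the instantiation chosen by the explicit type argument.

```go
package main

import "fmt"

type A[T any] struct{}

func (*A[T]) DoA() string { return "do A" }

type B[T any] struct{}

func (*B[T]) DoB() string { return "do B" }

// Handle matches only the instantiations selected by its own type argument T.
func Handle[T any](s interface{}) string {
	switch x := s.(type) {
	case *A[T]:
		return x.DoA()
	case *B[T]:
		return x.DoB()
	default:
		return fmt.Sprintf("no match for %T", s)
	}
}

func main() {
	i := &A[int]{}
	fmt.Println(Handle[int](i))         // do A
	fmt.Println(Handle[int](&B[int]{})) // do B
	// Falls through to default: *A[int] is not *A[string].
	fmt.Println(Handle[string](i))
}
```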

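However, when Handle is a method, as in your database code, this approach has a drawback: Go methods cannot declare their own type parameters, so T would have to be a type parameter of CommandManager itself and would be fixed once, when the receiver type is instantiated.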

To avoid this, you can make Handle a top-level function:

func Handle[T any](ctx context.Context, cmd CMD) error {
    switch t := cmd.(type) {
    case *TransactionalCommand[T]:
        return handleTransactionalCommand(ctx, t)
    default:
        fmt.Printf("%T\n", cmd)
        return ErrInvalidCommand
    }
}

Then you have the problem of how to supply the argument db T to the command function. For this, you might:

  • simply pass an additional *db argument to Handle and handleTransactionalCommand; this also lets the compiler infer the type parameter. Call it as Handle(ctx, &db{}, tFn). Playground: https://go.dev/play/p/6WESb86KN5D

  • pass an instance of CommandManager (like the solution above, but with *db wrapped inside it). This is much more verbose, as it requires explicit instantiation everywhere. Playground: https://go.dev/play/p/SpXczsUM5aW

  • use a parameterized interface instead, as below, so you don't need a type switch at all. Playground: https://go.dev/play/p/EgULEIL6AV5

type CMD[T any] interface {
    Exec(ctx context.Context, db T) error
}