Python type hint for objects that have “@” (matrix-multiply)

I have a function fun() that accepts a NumPy ArrayLike and a “matrix”, and returns a NumPy array.

from numpy.typing import ArrayLike
import numpy as np

def fun(A, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

What’s the correct type for objects that support the @ operation? Note that fun() could also accept a scipy.sparse matrix, and perhaps other types.

Answer

You can use typing.Protocol to assert that the type implements __matmul__.

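import typing

import numpy as np
from numpy.typing import ArrayLike


class SupportsMatrixMultiplication(typing.Protocol):
    # Any class that defines __matmul__ (i.e. supports "A @ x") matches structurally.
    def __matmul__(self, x):
        ...


def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0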

You can, I believe, refine this further by adding a type hint for x and a return type hint to __matmul__, if you want the protocol to require more than mere support for the @ operator.
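As a rough sketch of that refinement, the protocol below annotates both the argument and the return value of __matmul__. To keep it dependency-free it uses plain lists instead of NumPy; the names SupportsMatMul, Diagonal and apply are hypothetical and exist only for illustration. With NumPy you would presumably annotate the argument as ArrayLike and the return value as np.ndarray instead. The @runtime_checkable decorator is optional and is included only so that isinstance() can be used to demonstrate the structural match.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsMatMul(Protocol):
    """Structural type: anything that defines `A @ x`."""

    # Tighten `Any` to the operand and result types you actually expect.
    def __matmul__(self, other: Any) -> Any:
        ...


class Diagonal:
    """Hypothetical toy operator: a diagonal matrix stored as a list."""

    def __init__(self, diag: list[float]) -> None:
        self.diag = diag

    def __matmul__(self, other: list[float]) -> list[float]:
        # Diagonal matrix times vector: scale each entry.
        return [d * v for d, v in zip(self.diag, other)]


def apply(A: SupportsMatMul, x: Any) -> Any:
    # Accepts any object with __matmul__; no inheritance is required.
    return A @ x


op = Diagonal([1.0, 2.0, 3.0])
result = apply(op, [4.0, 5.0, 6.0])

# isinstance() only checks that the method exists, not its signature.
is_supported = isinstance(op, SupportsMatMul)
is_rejected = isinstance(3.5, SupportsMatMul)  # float has no __matmul__
```

Note that Diagonal never inherits from SupportsMatMul; a static type checker accepts it because the method signature is compatible, which is also why third-party matrix types can match without modification.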