Can I use something like enable_if with an implicit conversion operator?

Posted 2024-12-09 03:59:46

I have a (basically completed) matrix class (later in this post). If the matrix is a 1x1 matrix, then I'd like to have an implicit conversion to the backing type (e.g. a 1x1 float matrix should convert to a float).

Is there a way to do that without creating a specialization and duplicating all the methods inside Matrix? (e.g. using something like std::enable_if?) I basically want to enable the implicit conversion if and only if ROWS == COLS == 1.

#include <cassert>
#include <cstddef>

template <std::size_t ROWS, std::size_t COLS = 1, typename BackingType = float>
class Matrix
{
    BackingType data[ROWS][COLS];
public:
    Matrix()
    {
        for(std::size_t rdx = 0; rdx < ROWS; ++rdx)
        {
            for (std::size_t cdx = 0; cdx < COLS; ++cdx)
            {
                data[rdx][cdx] = 0;
            }
        }
    }
    const BackingType& Member(std::size_t index) const
    {
        assert(index < ROWS*COLS);
        return data[index / COLS][index % COLS]; // row-major flat index; no const-dropping cast
    }
    BackingType& Member(std::size_t index)
    {
        assert(index < ROWS*COLS);
        return data[index / COLS][index % COLS]; // row-major flat index
    }
    const BackingType& Member(std::size_t rowIndex, std::size_t colIndex) const
    {
        assert(rowIndex < ROWS);
        assert(colIndex < COLS);
        return data[rowIndex][colIndex];
    }
    BackingType& Member(std::size_t rowIndex, std::size_t colIndex)
    {
        assert(rowIndex < ROWS);
        assert(colIndex < COLS);
        return data[rowIndex][colIndex];
    }
    Matrix<COLS, ROWS, BackingType> Transpose() const
    {
        Matrix<COLS, ROWS, BackingType> result;
        for(std::size_t rowIdx = 0; rowIdx < ROWS; rowIdx++)
        {
            for (std::size_t colIdx = 0; colIdx < COLS; ++colIdx)
            {
                result.Member(colIdx, rowIdx) = Member(rowIdx, colIdx);
            }
        }
        return result;
    }
    template <std::size_t otherRows, std::size_t otherCols>
    Matrix<ROWS + otherRows, COLS, BackingType> AugmentBelow(const Matrix<otherRows, otherCols, BackingType>& other) const
    {
        static_assert(COLS == otherCols, "Columns must match for a vertical augmentation.");
        Matrix<ROWS + otherRows, COLS, BackingType> result;
        for (std::size_t curRow = 0; curRow < ROWS; ++curRow)
        {
            for (std::size_t curCol = 0; curCol < COLS; ++curCol)
            {
                result.Member(curRow, curCol) = Member(curRow, curCol);
            }
        }
        for (std::size_t curRow = ROWS; curRow < (ROWS + otherRows); ++curRow)
        {
            for (std::size_t curCol = 0; curCol < COLS; ++curCol)
            {
                result.Member(curRow, curCol) = other.Member(curRow - ROWS, curCol);
            }
        }
        return result;
    }
    template <std::size_t otherRows, std::size_t otherCols>
    Matrix<ROWS, COLS + otherCols, BackingType> AugmentRight(const Matrix<otherRows, otherCols, BackingType>& other) const
    {
        static_assert(ROWS == otherRows, "Rows must match for a horizontal augmentation.");
        Matrix<ROWS, COLS + otherCols, BackingType> result;
        for (std::size_t curRow = 0; curRow < ROWS; ++curRow)
        {
            for (std::size_t curCol = 0; curCol < COLS; ++curCol)
            {
                result.Member(curRow, curCol) = Member(curRow, curCol);
            }
            for (std::size_t curCol = COLS; curCol < (COLS + otherCols); ++curCol)
            {
                result.Member(curRow, curCol) = other.Member(curRow, curCol - COLS);
            }
        }
        return result;
    }
    static Matrix<ROWS, COLS, BackingType> Identity()
    {
        static_assert(ROWS == COLS, "Identity matrices are always square.");
        Matrix<ROWS, COLS, BackingType> result;
        for (std::size_t diagonal = 0; diagonal < ROWS; ++diagonal)
        {
            result.Member(diagonal, diagonal) = 1;
        }
        return result;
    }
};

template <std::size_t leftRows, std::size_t leftCols, std::size_t rightRows, std::size_t rightCols, typename BackingType>
inline Matrix<leftRows, rightCols, BackingType> operator*(const Matrix<leftRows, leftCols, BackingType>& left, const Matrix<rightRows, rightCols, BackingType>& right)
{
    static_assert(leftCols == rightRows, "Matrix multiplications require that the left column count and the right row count match.");
    Matrix<leftRows, rightCols, BackingType> result;
    for (std::size_t i = 0; i < leftRows; ++i)
    {
        for (std::size_t j = 0; j < rightCols; ++j)
        {
            BackingType curItem = 0;
            for (std::size_t k = 0; k < leftCols; ++k)
            {
                curItem += left.Member(i, k) * right.Member(k, j);
            }
            result.Member(i, j) = curItem;
        }
    }
    return result;
}

template <std::size_t rows, std::size_t cols, typename BackingType>
inline Matrix<rows, cols, BackingType> operator*(BackingType val, const Matrix<rows, cols, BackingType>& target)
{
    Matrix<rows, cols, BackingType> result = target;
    for (std::size_t i = 0; i < rows; ++i)
    {
        for (std::size_t j = 0; j < cols; ++j)
        {
            result.Member(i, j) *= val; // scale one element; Matrix has no operator*=
        }
    }
    return result;
}

Answers (4)

埖埖迣鎅 2024-12-16 03:59:46

An alternative:

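#include <type_traits>

template<typename T, int Rows, int Cols>
struct matrix {
    T data[Rows * Cols];

    template<
        // we need to 'duplicate' the template parameters
        // because SFINAE only applies to conditions that depend on
        // the function template's own parameters; defaulting R and C
        // to Rows and Cols makes the condition below depend on them
        int R = Rows
        , int C = Cols

        // C++11 allows default template arguments on function
        // templates, so SFINAE can be used right here!
        , typename = typename std::enable_if<
            (R == 1 && C == 1)
        >::type
    >
    operator T() const { return data[0]; }
};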
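Here is a minimal sketch of the same default-template-argument technique applied to a cut-down version of the question's Matrix. Only Member and the matrix product are kept, for brevity, and the parameter names R and C are carried over from the answer. A 1x3 row vector times a 3x1 column vector gives a 1x1 matrix, which then converts to float implicitly. A 2x2 matrix has no such conversion.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

template <std::size_t ROWS, std::size_t COLS = 1, typename BackingType = float>
class Matrix
{
    BackingType data[ROWS][COLS] = {};  // zero-initialised storage
public:
    const BackingType& Member(std::size_t r, std::size_t c) const
    {
        assert(r < ROWS && c < COLS);
        return data[r][c];
    }
    BackingType& Member(std::size_t r, std::size_t c)
    {
        assert(r < ROWS && c < COLS);
        return data[r][c];
    }

    // R and C copy the class parameters so that the enable_if condition
    // depends on this function template's own parameters (needed for SFINAE).
    template <std::size_t R = ROWS, std::size_t C = COLS,
              typename = typename std::enable_if<(R == 1 && C == 1)>::type>
    operator BackingType() const
    {
        return data[0][0];
    }
};

// Matrix product: (L x K) * (K x R) -> (L x R).
template <std::size_t L, std::size_t K, std::size_t R, typename T>
Matrix<L, R, T> operator*(const Matrix<L, K, T>& left, const Matrix<K, R, T>& right)
{
    Matrix<L, R, T> result;
    for (std::size_t i = 0; i < L; ++i)
        for (std::size_t j = 0; j < R; ++j)
        {
            T sum = 0;
            for (std::size_t k = 0; k < K; ++k)
                sum += left.Member(i, k) * right.Member(k, j);
            result.Member(i, j) = sum;
        }
    return result;
}
```

With this in place, an expression like float dot = row * col; compiles for a row and a column vector. The same line with two 2x2 operands is rejected, because the conversion template is removed from overload resolution.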
A君 2024-12-16 03:59:46

Here's a crude hackaround, using conditional rather than enable_if:

#include <type_traits> // std::conditional

template <typename T, int N, int M>
struct Matrix
{
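  // Declared but never defined: for non-1x1 matrices the operator converts
  // to this incomplete type, which can never be used as a T.
  struct IncompleteType;

  T buf[N * M];
  operator typename std::conditional<N == 1 && M == 1, T, IncompleteType>::type () const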
  {
    return buf[0];
  }
};

With some work one could probably make the compiler error a bit more meaningful, too.
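One possible step in that direction, assuming the conditional approach above is kept, is to give the never-defined type a descriptive name. When an implicit conversion fails, compilers usually list the candidate conversion functions, so that name will likely appear in the diagnostic. The exact wording depends on the compiler. The name OnlyOneByOneMatricesConvertToScalar is made up for illustration.

```cpp
#include <type_traits>

template <typename T, int N, int M>
struct Matrix
{
    // Hypothetical name, deliberately never defined: it is meant to show up
    // in diagnostics when a non-1x1 matrix is converted to T.
    struct OnlyOneByOneMatricesConvertToScalar;

    T buf[N * M];

    // operator T() for 1x1 matrices; otherwise a conversion to an incomplete
    // type, which cannot be used to produce a T.
    operator typename std::conditional<N == 1 && M == 1, T,
                                       OnlyOneByOneMatricesConvertToScalar>::type () const
    {
        return buf[0];
    }
};
```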

入画浅相思 2024-12-16 03:59:46
#include <type_traits> // std::conditional

template <typename T, int N, int M>
struct Matrix
{    
  T buf[N * M];
  operator typename std::conditional<N == 1 && M == 1, T, void>::type () const
  {
    return buf[0];
  }
};

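In this case there is no need to define IncompleteType; using void suffices. For a matrix that is not 1x1, the operator becomes operator void() const. That is a legal declaration, but a conversion function to void is never used for implicit conversions, so such a matrix has no usable conversion. The body would be ill-formed for void, because it returns a value. It is only instantiated if the operator is actually called, so strictly speaking this works through lazy instantiation of member definitions rather than through SFINAE.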

冷弦 2024-12-16 03:59:46

No, you can't use enable_if with an implicit conversion operator; there are no types to which you can apply it. Move all your common functionality to a matrix_base template class, and then have the specializations inherit from it and add the special functionality. The alternative is to implement the method anyway and place a static assert within it, so that instantiating it causes a compiler error. Note that this would prevent explicit instantiation of your class.
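The base-class suggestion might look roughly like this, assuming C++11. MatrixBase is a hypothetical name for the shared base. Only Member and Transpose are carried over from the question's class, to keep the sketch short. The general Matrix adds nothing. The 1x1 partial specialisation adds only the conversion operator, so no method has to be written twice.

```cpp
#include <cassert>
#include <cstddef>

// Forward declaration (with the defaults) so MatrixBase can name Matrix.
template <std::size_t ROWS, std::size_t COLS = 1, typename BackingType = float>
class Matrix;

// Hypothetical name: everything shared by all matrix sizes lives here.
template <std::size_t ROWS, std::size_t COLS, typename BackingType>
class MatrixBase
{
protected:
    BackingType data[ROWS][COLS] = {};  // zero-initialised
public:
    const BackingType& Member(std::size_t r, std::size_t c) const
    {
        assert(r < ROWS && c < COLS);
        return data[r][c];
    }
    BackingType& Member(std::size_t r, std::size_t c)
    {
        assert(r < ROWS && c < COLS);
        return data[r][c];
    }
    // Returns Matrix (not MatrixBase) so a 1x1 result gets the conversion.
    Matrix<COLS, ROWS, BackingType> Transpose() const
    {
        Matrix<COLS, ROWS, BackingType> result;
        for (std::size_t r = 0; r < ROWS; ++r)
            for (std::size_t c = 0; c < COLS; ++c)
                result.Member(c, r) = data[r][c];
        return result;
    }
};

// General case: nothing beyond the shared functionality.
template <std::size_t ROWS, std::size_t COLS, typename BackingType>
class Matrix : public MatrixBase<ROWS, COLS, BackingType> {};

// 1x1 specialisation: inherits everything and adds only the conversion.
template <typename BackingType>
class Matrix<1, 1, BackingType> : public MatrixBase<1, 1, BackingType>
{
public:
    operator BackingType() const { return this->data[0][0]; }
};
```

Methods that build new matrices (Transpose here, and the Augment methods in the question) should return Matrix rather than MatrixBase. That way, results that happen to be 1x1 pick up the conversion.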

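The static assert alternative could be sketched as follows. This is a cut-down illustration, not the answerer's code. The conversion operator is declared for every size, but the assertion is only evaluated when the operator's body is instantiated. Merely creating a 2x2 matrix is therefore fine, while converting one fails to compile with a readable message.

```cpp
#include <cstddef>

template <std::size_t ROWS, std::size_t COLS = 1, typename BackingType = float>
class Matrix
{
    BackingType data[ROWS][COLS] = {};  // zero-initialised
public:
    BackingType& Member(std::size_t r, std::size_t c) { return data[r][c]; }

    // Declared for every size; the static_assert only fires if this body is
    // instantiated, i.e. when a non-1x1 matrix is actually converted.
    operator BackingType() const
    {
        static_assert(ROWS == 1 && COLS == 1,
                      "Only 1x1 matrices convert to their backing type.");
        return data[0][0];
    }
};

// An explicit instantiation such as the following would instantiate the
// operator's body as well and trigger the static_assert:
// template class Matrix<2, 2>;
```

Because the operator is declared for every size, traits such as std::is_convertible will typically report a 2x2 matrix as convertible to float. That is a drawback compared with the SFINAE-based answers.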