diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index ede5a3d4101..64ecb73a726 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -38,7 +39,9 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) { "Size must be non-negative, got %zd at dimension %zd", static_cast(sizes[i]), i); - numel *= sizes[i]; + bool overflow = + c10::mul_overflows(numel, static_cast(sizes[i]), &numel); + ET_CHECK_MSG(!overflow, "numel overflowed"); } return numel; } @@ -66,7 +69,11 @@ TensorImpl::TensorImpl( } size_t TensorImpl::nbytes() const { - return numel_ * elementSize(type_); + size_t result; + bool overflow = c10::mul_overflows( + static_cast(numel_), elementSize(type_), &result); + ET_CHECK_MSG(!overflow, "nbytes overflowed"); + return result; } // Return the size of one element of the tensor diff --git a/runtime/core/portable_type/test/tensor_impl_test.cpp b/runtime/core/portable_type/test/tensor_impl_test.cpp index 0b8ae05f4da..08b26c64e09 100644 --- a/runtime/core/portable_type/test/tensor_impl_test.cpp +++ b/runtime/core/portable_type/test/tensor_impl_test.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -449,3 +450,11 @@ TEST_F(TensorImplTest, TestResizingTensorToZeroAndBack) { EXPECT_GT(t.numel(), 0); EXPECT_EQ(t.data(), data); } + +TEST_F(TensorImplTest, TestNbytesOverflow) { + SizesType sizes[3] = { + static_cast(1 << 21), + static_cast(1 << 21), + static_cast(1 << 21)}; + ET_EXPECT_DEATH(TensorImpl t(ScalarType::Float, 3, sizes), ""); +}