diff --git a/runtime/mem_obj/image.cpp b/runtime/mem_obj/image.cpp index e30d55a512..12a6a4631f 100644 --- a/runtime/mem_obj/image.cpp +++ b/runtime/mem_obj/image.cpp @@ -417,7 +417,6 @@ cl_int Image::validate(Context *context, const cl_image_desc *imageDesc, const void *hostPtr) { auto pDevice = context->getDevice(0); - cl_int retVal = CL_SUCCESS; size_t srcSize = 0; size_t retSize = 0; const size_t *maxWidth = nullptr; @@ -430,14 +429,12 @@ cl_int Image::validate(Context *context, Image *parentImage = castToObject(imageDesc->mem_object); Buffer *parentBuffer = castToObject(imageDesc->mem_object); - switch (imageDesc->image_type) { - - case CL_MEM_OBJECT_IMAGE2D: + if (imageDesc->image_type == CL_MEM_OBJECT_IMAGE2D) { pDevice->getCap(reinterpret_cast(maxWidth), srcSize, retSize); pDevice->getCap(reinterpret_cast(maxHeight), srcSize, retSize); if (imageDesc->image_width > *maxWidth || imageDesc->image_height > *maxHeight) { - retVal = CL_INVALID_IMAGE_SIZE; + return CL_INVALID_IMAGE_SIZE; } if (parentBuffer) { // Image 2d from buffer pDevice->getCap(reinterpret_cast(pitchAlignment), srcSize, retSize); @@ -446,52 +443,43 @@ cl_int Image::validate(Context *context, if ((imageDesc->image_row_pitch % (*pitchAlignment)) || ((parentBuffer->getFlags() & CL_MEM_USE_HOST_PTR) && (reinterpret_cast(parentBuffer->getHostPtr()) % (*baseAddressAlignment))) || (imageDesc->image_height * (imageDesc->image_row_pitch != 0 ? imageDesc->image_row_pitch : imageDesc->image_width) > parentBuffer->getSize())) { - retVal = CL_INVALID_IMAGE_FORMAT_DESCRIPTOR; + return CL_INVALID_IMAGE_FORMAT_DESCRIPTOR; } else if (flags & (CL_MEM_USE_HOST_PTR | CL_MEM_COPY_HOST_PTR)) { - retVal = CL_INVALID_VALUE; + return CL_INVALID_VALUE; } } if (parentImage && !IsNV12Image(&parentImage->getImageFormat())) { // Image 2d from image 2d if (!parentImage->hasSameDescriptor(*imageDesc) || !parentImage->hasValidParentImageFormat(surfaceFormat->OCLImageFormat)) { - retVal = CL_INVALID_IMAGE_FORMAT_DESCRIPTOR; + return CL_INVALID_IMAGE_FORMAT_DESCRIPTOR; } } - if (!imageDesc->mem_object && (imageDesc->image_width == 0 || - imageDesc->image_height == 0)) { - retVal = CL_INVALID_IMAGE_DESCRIPTOR; + if (!(parentImage && IsNV12Image(&parentImage->getImageFormat())) && + (imageDesc->image_width == 0 || imageDesc->image_height == 0)) { + return CL_INVALID_IMAGE_DESCRIPTOR; } - break; - default: - break; } if (hostPtr == nullptr) { if (imageDesc->image_row_pitch != 0 && imageDesc->mem_object == nullptr) { - retVal = CL_INVALID_IMAGE_DESCRIPTOR; + return CL_INVALID_IMAGE_DESCRIPTOR; } } else { if (imageDesc->image_row_pitch != 0) { if (imageDesc->image_row_pitch % surfaceFormat->ImageElementSizeInBytes != 0 || imageDesc->image_row_pitch < imageDesc->image_width * surfaceFormat->ImageElementSizeInBytes) { - retVal = CL_INVALID_IMAGE_DESCRIPTOR; + return CL_INVALID_IMAGE_DESCRIPTOR; } } } if (parentBuffer && imageDesc->image_type != CL_MEM_OBJECT_IMAGE1D_BUFFER && imageDesc->image_type != CL_MEM_OBJECT_IMAGE2D) { - retVal = CL_INVALID_IMAGE_DESCRIPTOR; + return CL_INVALID_IMAGE_DESCRIPTOR; } if (parentImage && imageDesc->image_type != CL_MEM_OBJECT_IMAGE2D) { - retVal = CL_INVALID_IMAGE_DESCRIPTOR; + return CL_INVALID_IMAGE_DESCRIPTOR; } - if (retVal != CL_SUCCESS) { - return retVal; - } - - retVal = validateImageTraits(context, flags, &surfaceFormat->OCLImageFormat, imageDesc, hostPtr); - - return retVal; + return validateImageTraits(context, flags, &surfaceFormat->OCLImageFormat, imageDesc, hostPtr); } cl_int Image::validateImageFormat(const cl_image_format *imageFormat) { diff --git a/unit_tests/mem_obj/image2d_from_buffer_tests.cpp b/unit_tests/mem_obj/image2d_from_buffer_tests.cpp index 00ded62a9e..8b4f2e4dce 100644 --- a/unit_tests/mem_obj/image2d_from_buffer_tests.cpp +++ b/unit_tests/mem_obj/image2d_from_buffer_tests.cpp @@ -105,13 +105,13 @@ TEST_F(Image2dFromBufferTest, CalculateRowPitch) { EXPECT_EQ(1024u, imageFromBuffer->getImageDesc().image_row_pitch); delete imageFromBuffer; } -TEST_F(Image2dFromBufferTest, InvalidRowPitch) { +TEST_F(Image2dFromBufferTest, givenInvalidRowPitchWhenCreateImage2dFromBufferThenReturnsError) { char ptr[10]; imageDesc.image_row_pitch = 257; cl_mem_flags flags = CL_MEM_READ_ONLY; auto surfaceFormat = (SurfaceFormatInfo *)Image::getSurfaceFormatFromTable(flags, &imageFormat); retVal = Image::validate(&context, flags, surfaceFormat, &imageDesc, ptr); - EXPECT_EQ(CL_INVALID_IMAGE_DESCRIPTOR, retVal); + EXPECT_EQ(CL_INVALID_IMAGE_FORMAT_DESCRIPTOR, retVal); } TEST_F(Image2dFromBufferTest, givenRowPitchThatIsGreaterThenComputedWhenImageIsCreatedThenPassedRowPitchIsUsedInsteadOfComputed) { diff --git a/unit_tests/mem_obj/image_validate_tests.cpp b/unit_tests/mem_obj/image_validate_tests.cpp index a154e5ccf8..07843607a5 100644 --- a/unit_tests/mem_obj/image_validate_tests.cpp +++ b/unit_tests/mem_obj/image_validate_tests.cpp @@ -925,3 +925,24 @@ TEST(ImageValidatorTest, givenNV12Image2dAsParentImageWhenValidateImageZeroSized EXPECT_EQ(CL_SUCCESS, Image::validate(&context, 0, &surfaceFormat, &descriptor, dummyPtr)); }; +TEST(ImageValidatorTest, givenNonNV12Image2dAsParentImageWhenValidateImageZeroSizedThenReturnsError) { + NullImage image; + cl_image_desc descriptor; + MockContext context; + void *dummyPtr = reinterpret_cast(0x17); + SurfaceFormatInfo surfaceFormat; + image.imageFormat.image_channel_order = CL_BGRA; + image.imageFormat.image_channel_data_type = CL_UNORM_INT8; + + surfaceFormat.OCLImageFormat.image_channel_order = CL_sBGRA; + surfaceFormat.OCLImageFormat.image_channel_data_type = CL_UNORM_INT8; + descriptor.image_type = CL_MEM_OBJECT_IMAGE2D; + descriptor.image_height = 0; + descriptor.image_width = 0; + descriptor.image_row_pitch = image.getHostPtrRowPitch(); + descriptor.image_slice_pitch = image.getHostPtrSlicePitch(); + image.imageDesc = descriptor; + descriptor.mem_object = ℑ + + EXPECT_EQ(CL_INVALID_IMAGE_DESCRIPTOR, Image::validate(&context, 0, &surfaceFormat, &descriptor, dummyPtr)); +};