diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 7b92af01a..99f7ac40d 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -114,8 +114,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "Matmul_dot",
     # %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
     "Matmul_vecmat",
-    # ERROR: shape (torch.Size([2, 3, 4, 9])) is not equal to golden shape (torch.Size([2, 3, 6, 10]))
-    "UpSampleNearest2dDynamicFactor_basic",
     # https://github.com/llvm/torch-mlir/issues/1611
     # error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
     "Aten_EmbeddingBagExample_basic",
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index c0eecb67d..db61eb4a9 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -757,24 +757,6 @@ public:
 };
 } // namespace
 
-// `getScaledDims` scales the `dim` value with a scale factor `ScaleFactor`.
-// The `dim` and `scaleFactor` are assumed to be of index and float type
-// respectively. `scaledDim = int(floor(float(dim) * scaleFactor))`.
-static Value getScaledDims(OpBuilder &builder, Location loc, Value dim,
-                           Value scaleFactor) {
-
-  Value dimInt = castIndexToInt64(builder, loc, dim);
-  Value dimFp =
-      builder.create<arith::SIToFPOp>(loc, scaleFactor.getType(), dimInt);
-  Value scaleDim = builder.create<arith::MulFOp>(loc, dimFp, scaleFactor);
-  Value floorDim = builder.create<math::FloorOp>(loc, scaleDim);
-  Value scaledDimToIndex = castIntToIndex(
-      builder, loc,
-      builder.create<arith::FPToSIOp>(loc, dimInt.getType(), floorDim));
-
-  return scaledDimToIndex;
-}
-
 // `getScaleFactor` returns the scale factor from input to output dimension.
 // The `dim` and `scaledDim` are assumed to be of index and int64 type
 // respectively. scale_factor = (scaled_dim // dim).
@@ -820,6 +802,9 @@ public:
     // The dimension at which the scaling starts.
     unsigned hDimOffset = 2;
 
+    Value originalHeight = dims[hDimOffset];
+    Value originalWidth = dims[hDimOffset + 1];
+
     SmallVector<Value> outputSizeTorchInt;
     if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
       return rewriter.notifyMatchFailure(
@@ -828,42 +813,39 @@
     SmallVector<Value> outputSizeIntValues;
     outputSizeIntValues = getTypeConvertedValues(
         rewriter, loc, getTypeConverter(), outputSizeTorchInt);
-
+
     if (!op.scales_h().getType().isa<Torch::NoneType>()) {
       // Convert float values to int values.
      // int_value = (int64_t)ceil(float_value)
      Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_h());
-      Value intVal = rewriter.create<arith::FPToSIOp>(
-          loc, rewriter.getI64Type(), ceilVal);
+      Value intVal =
+          rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), ceilVal);
       scaleFactorsInt.push_back(intVal);
-      dims[hDimOffset] = getScaledDims(
-          rewriter, loc, dims[hDimOffset], adaptor.scales_h());
     } else {
-      auto scaleFactorVal = getScaleFactor(
-          rewriter, loc, dims[hDimOffset], outputSizeIntValues[0]);
+      auto scaleFactorVal =
+          getScaleFactor(rewriter, loc, originalHeight, outputSizeIntValues[0]);
       scaleFactorsInt.push_back(scaleFactorVal);
-      dims[hDimOffset] =
-          castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
     }
 
     if (!op.scales_w().getType().isa<Torch::NoneType>()) {
       // Convert float values to int values.
      // int_value = (int64_t)ceil(float_value)
      Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_w());
-      Value intVal = rewriter.create<arith::FPToSIOp>(
-          loc, rewriter.getI64Type(), ceilVal);
+      Value intVal =
+          rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), ceilVal);
       scaleFactorsInt.push_back(intVal);
-      dims[hDimOffset + 1] = getScaledDims(
-          rewriter, loc, dims[hDimOffset + 1], adaptor.scales_w());
     } else {
-      auto scaleFactorVal = getScaleFactor(
-          rewriter, loc, dims[hDimOffset + 1], outputSizeIntValues[1]);
+      auto scaleFactorVal =
+          getScaleFactor(rewriter, loc, originalWidth, outputSizeIntValues[1]);
       scaleFactorsInt.push_back(scaleFactorVal);
-      dims[hDimOffset + 1] =
-          castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
     }
-
+    // The output size is always as provided by `output_size`. However, the
+    // scaling is determined by the `scales_h` and `scales_w` if provided.
+    dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
+    dims[hDimOffset + 1] =
+        castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
+
     Value outTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(dims), elementType);
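
For context, a minimal PyTorch sketch of the semantics the new code comments describe: the result shape of `aten.upsample_nearest2d` always follows `output_size`, while `scales_h`/`scales_w`, if given, only determine the scaling. The input shape and scale values below are hypothetical (chosen so output_size [6, 10] matches scales 1.5 and 2.0 over a 4 x 5 spatial input); this is not the UpSampleNearest2d e2e test itself.

import torch

# Hypothetical NCHW input.
x = torch.rand(2, 3, 4, 5)

# The output spatial shape comes from output_size even though float scales
# are also supplied.
y = torch.ops.aten.upsample_nearest2d(x, [6, 10], 1.5, 2.0)
print(y.shape)  # torch.Size([2, 3, 6, 10])

# In the lowering above, the integer scale factor per spatial dimension is
# ceil(scale) when scales_h/scales_w are provided (ceil(1.5) == 2,
# ceil(2.0) == 2), and scaled_dim // dim via getScaleFactor otherwise.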