mirror of https://github.com/llvm/torch-mlir
[torchdynamo] Fix output size computation for upsample_nearest2d
parent
883b986eda
commit
ecb09c2fc3
|
@ -114,8 +114,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"Matmul_dot",
|
||||
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
|
||||
"Matmul_vecmat",
|
||||
# ERROR: shape (torch.Size([2, 3, 4, 9])) is not equal to golden shape (torch.Size([2, 3, 6, 10]))
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
# https://github.com/llvm/torch-mlir/issues/1611
|
||||
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
|
|
|
@ -757,24 +757,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// `getScaledDims` scales the `dim` value with a scale factor `ScaleFactor`.
|
||||
// The `dim` and `scaleFactor` are assumed to be of index and float type
|
||||
// respectively. `scaledDim = int(floor(float(dim) * scaleFactor))`.
|
||||
static Value getScaledDims(OpBuilder &builder, Location loc, Value dim,
|
||||
Value scaleFactor) {
|
||||
|
||||
Value dimInt = castIndexToInt64(builder, loc, dim);
|
||||
Value dimFp =
|
||||
builder.create<arith::SIToFPOp>(loc, scaleFactor.getType(), dimInt);
|
||||
Value scaleDim = builder.create<arith::MulFOp>(loc, dimFp, scaleFactor);
|
||||
Value floorDim = builder.create<math::FloorOp>(loc, scaleDim);
|
||||
Value scaledDimToIndex = castIntToIndex(
|
||||
builder, loc,
|
||||
builder.create<arith::FPToSIOp>(loc, dimInt.getType(), floorDim));
|
||||
|
||||
return scaledDimToIndex;
|
||||
}
|
||||
|
||||
// `getScaleFactor` returns the scale factor from input to output dimension.
|
||||
// The `dim` and `scaledDim` are assumed to be of index and int64 type
|
||||
// respectively. scale_factor = (scaled_dim // dim).
|
||||
|
@ -820,6 +802,9 @@ public:
|
|||
// The dimension at which the scaling starts.
|
||||
unsigned hDimOffset = 2;
|
||||
|
||||
Value originalHeight = dims[hDimOffset];
|
||||
Value originalWidth = dims[hDimOffset + 1];
|
||||
|
||||
SmallVector<Value, 2> outputSizeTorchInt;
|
||||
if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -828,42 +813,39 @@ public:
|
|||
SmallVector<Value, 2> outputSizeIntValues;
|
||||
outputSizeIntValues = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||
|
||||
|
||||
if (!op.scales_h().getType().isa<Torch::NoneType>()) {
|
||||
// Convert float values to int values.
|
||||
// int_value = (int64_t)ceil(float_value)
|
||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_h());
|
||||
Value intVal = rewriter.create<arith::FPToSIOp>(
|
||||
loc, rewriter.getI64Type(), ceilVal);
|
||||
Value intVal =
|
||||
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), ceilVal);
|
||||
scaleFactorsInt.push_back(intVal);
|
||||
dims[hDimOffset] = getScaledDims(
|
||||
rewriter, loc, dims[hDimOffset], adaptor.scales_h());
|
||||
} else {
|
||||
auto scaleFactorVal = getScaleFactor(
|
||||
rewriter, loc, dims[hDimOffset], outputSizeIntValues[0]);
|
||||
auto scaleFactorVal =
|
||||
getScaleFactor(rewriter, loc, originalHeight, outputSizeIntValues[0]);
|
||||
scaleFactorsInt.push_back(scaleFactorVal);
|
||||
dims[hDimOffset] =
|
||||
castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
|
||||
}
|
||||
|
||||
if (!op.scales_w().getType().isa<Torch::NoneType>()) {
|
||||
// Convert float values to int values.
|
||||
// int_value = (int64_t)ceil(float_value)
|
||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_w());
|
||||
Value intVal = rewriter.create<arith::FPToSIOp>(
|
||||
loc, rewriter.getI64Type(), ceilVal);
|
||||
Value intVal =
|
||||
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), ceilVal);
|
||||
scaleFactorsInt.push_back(intVal);
|
||||
dims[hDimOffset + 1] = getScaledDims(
|
||||
rewriter, loc, dims[hDimOffset + 1], adaptor.scales_w());
|
||||
} else {
|
||||
auto scaleFactorVal = getScaleFactor(
|
||||
rewriter, loc, dims[hDimOffset + 1], outputSizeIntValues[1]);
|
||||
auto scaleFactorVal =
|
||||
getScaleFactor(rewriter, loc, originalWidth, outputSizeIntValues[1]);
|
||||
scaleFactorsInt.push_back(scaleFactorVal);
|
||||
dims[hDimOffset + 1] =
|
||||
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
|
||||
}
|
||||
|
||||
|
||||
// The output size is always as provided by `output_size`. However, the
|
||||
// scaling is determined by the `scales_h` and `scales_w` if provided.
|
||||
dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
|
||||
dims[hDimOffset + 1] =
|
||||
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
|
||||
|
||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(dims), elementType);
|
||||
|
||||
|
|
Loading…
Reference in New Issue