mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix aten.upsample_nearest2d op
-- aten.upsample_nearest2d.vec op is no longer present owing to https://github.com/pytorch/pytorch/pull/85638
-- So this commit adds a lowering for aten.upsample_nearest2d instead.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

pull/1612/head snapshot-20221118.661
parent 638a884e8c
commit 1d949f3ac2
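For context, the overload swap looks like this from Python; a minimal sketch with illustrative sizes (not taken from the commit). The removed aten::upsample_nearest2d.vec overload took an optional output_size and an optional scale_factors list, while plain aten::upsample_nearest2d takes a required output_size plus optional per-dimension scales_h/scales_w floats:

import torch

x = torch.arange(4, dtype=torch.float64).reshape(1, 1, 2, 2)

# aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)
out = torch.ops.aten.upsample_nearest2d(x, [4, 4], 2.0, 2.0)
assert out.shape == (1, 1, 4, 4)  # [N, C] pass through; [H, W] = output_size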
@@ -12,9 +12,7 @@
 from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
 
-REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    "UpSampleNearest2dDynamicFactor_basic",
-}
+REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
 
 EAGER_MODE_XFAIL_SET = {
     # RefBackend fails
@@ -22,6 +20,7 @@ EAGER_MODE_XFAIL_SET = {
     "QuantizedMLP_basic",
     "Matmul_vecmat",
     "BatchMlpLayerModule_basic",
+    "UpSampleNearest2dDynamicFactor_basic",
 }
 
 MHLO_PASS_SET = {
@@ -613,10 +612,6 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Bool_basic",
     "AtenIntTensorByteDtypeModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
-    "UpSampleNearest2dDynamicFactor_basic",
-    "UpSampleNearest2dDynamicSize_basic",
-    "UpSampleNearest2dStaticFactor_basic",
-    "UpSampleNearest2dStaticSize_basic",
     "Fill_TensorFloat32WithFloat32_basic",
     "Fill_TensorFloat32WithFloat64_basic",
     "Fill_TensorFloat32WithInt64_basic",
@@ -7968,27 +7968,28 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
   }];
 }
 
-def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
+def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
   AllowsTypeRefinement,
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
+  let summary = "Generated op for `aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchOptionalListOfTorchIntType:$output_size,
-    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchOptionalFloatType:$scales_h,
+    AnyTorchOptionalFloatType:$scales_w
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
+    ParseResult AtenUpsampleNearest2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
    }
-    void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
+    void AtenUpsampleNearest2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}
@@ -797,17 +797,17 @@ static Value getScaleFactor(OpBuilder &builder, Location loc, Value dim,
 // out_tensor[i, j, k, l] = input[i, j, k//H_factor, l//W_factor]
 
 namespace {
-class ConvertAtenUpsampleNearest2dVecOp
-    : public OpConversionPattern<AtenUpsampleNearest2dVecOp> {
+class ConvertAtenUpsampleNearest2dOp
+    : public OpConversionPattern<AtenUpsampleNearest2dOp> {
 
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenUpsampleNearest2dVecOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenUpsampleNearest2dOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
-    Value input = adaptor.input();
+    Value input = adaptor.self();
 
     Type resultType = getTypeConverter()->convertType(op.getResult().getType());
     auto inputType = input.getType().cast<RankedTensorType>();
@@ -820,48 +820,50 @@ public:
     // The dimension at which the scaling starts.
     unsigned hDimOffset = 2;
 
-    if (!adaptor.scale_factors().getType().isa<Torch::NoneType>()) {
-      SmallVector<Value, 2> scaleFactorsTorchFloat;
-      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: the scale_factors is not constructed from "
-                "ListConstruct");
-      SmallVector<Value, 2> scaleFactorsFloatValues;
-      scaleFactorsFloatValues = getTypeConvertedValues(
-          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    SmallVector<Value, 2> outputSizeTorchInt;
+    if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the output_size is not constructed from "
+              "ListConstruct");
+    SmallVector<Value, 2> outputSizeIntValues;
+    outputSizeIntValues = getTypeConvertedValues(
+        rewriter, loc, getTypeConverter(), outputSizeTorchInt);
+
+    if (!op.scales_h().getType().isa<Torch::NoneType>()) {
       // Convert float values to int values.
       // int_value = (int64_t)ceil(float_value)
-      for (auto floatValue : scaleFactorsFloatValues) {
-        Value ceilVal = rewriter.create<math::CeilOp>(loc, floatValue);
-        Value intVal = rewriter.create<arith::FPToSIOp>(
-            loc, rewriter.getI64Type(), ceilVal);
-        scaleFactorsInt.push_back(intVal);
-      }
-
-      for (unsigned i = 0; i < scaleFactorsFloatValues.size(); i++)
-        dims[hDimOffset + i] = getScaledDims(
-            rewriter, loc, dims[hDimOffset + i], scaleFactorsFloatValues[i]);
-
+      Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_h());
+      Value intVal = rewriter.create<arith::FPToSIOp>(
+          loc, rewriter.getI64Type(), ceilVal);
+      scaleFactorsInt.push_back(intVal);
+      dims[hDimOffset] = getScaledDims(
+          rewriter, loc, dims[hDimOffset], adaptor.scales_h());
     } else {
-      SmallVector<Value, 2> outputSizeTorchInt;
-      if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: the output_size is not constructed from "
-                "ListConstruct");
-      SmallVector<Value, 2> outputSizeIntValues;
-      outputSizeIntValues = getTypeConvertedValues(
-          rewriter, loc, getTypeConverter(), outputSizeTorchInt);
-
-      for (unsigned i = 0; i < outputSizeTorchInt.size(); i++) {
-        auto scaleFactorVal = getScaleFactor(
-            rewriter, loc, dims[hDimOffset + i], outputSizeIntValues[i]);
-        scaleFactorsInt.push_back(scaleFactorVal);
-        dims[hDimOffset + i] =
-            castIntToIndex(rewriter, loc, outputSizeIntValues[i]);
-      }
+      auto scaleFactorVal = getScaleFactor(
+          rewriter, loc, dims[hDimOffset], outputSizeIntValues[0]);
+      scaleFactorsInt.push_back(scaleFactorVal);
+      dims[hDimOffset] =
+          castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
     }
 
+    if (!op.scales_w().getType().isa<Torch::NoneType>()) {
+      // Convert float values to int values.
+      // int_value = (int64_t)ceil(float_value)
+      Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.scales_w());
+      Value intVal = rewriter.create<arith::FPToSIOp>(
+          loc, rewriter.getI64Type(), ceilVal);
+      scaleFactorsInt.push_back(intVal);
+      dims[hDimOffset + 1] = getScaledDims(
+          rewriter, loc, dims[hDimOffset + 1], adaptor.scales_w());
+    } else {
+      auto scaleFactorVal = getScaleFactor(
+          rewriter, loc, dims[hDimOffset + 1], outputSizeIntValues[1]);
+      scaleFactorsInt.push_back(scaleFactorVal);
+      dims[hDimOffset + 1] =
+          castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
+    }
+
     Value outTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(dims), elementType);
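To make the indexing rule from the pattern's comment concrete (out_tensor[i, j, k, l] = input[i, j, k//H_factor, l//W_factor]), here is a small Python reference sketch. The function name and loop formulation are illustrative; the integer factors play the role of the ceil-converted scaleFactorsInt values above, and PyTorch's own kernel may round source indices slightly differently at the boundaries:

import torch

def upsample_nearest2d_ref(x, out_h, out_w, h_factor, w_factor):
    # out[i, j, k, l] = x[i, j, k // h_factor, l // w_factor]
    n, c, _, _ = x.shape
    out = x.new_empty((n, c, out_h, out_w))
    for k in range(out_h):
        for l in range(out_w):
            out[:, :, k, l] = x[:, :, k // h_factor, l // w_factor]
    return out

x = torch.rand(1, 1, 6, 12, dtype=torch.float64)
out = upsample_nearest2d_ref(x, 18, 48, 3, 4)  # H: 6 -> 18, W: 12 -> 48
assert out.shape == (1, 1, 18, 48)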
@@ -1103,8 +1105,8 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
   target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
-  target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
-  patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dBackwardOp>();
   patterns.add<ConvertAtenUpsampleNearest2dBackwardOp>(typeConverter, context);
 }
@@ -700,7 +700,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
-          AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
+          AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
           AtenUpsampleNearest2dBackwardOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
@@ -7045,9 +7045,15 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<float>>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
 "  }\n"
 "}\n"
 "";
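The escaped strings above are just the serialized MLIR form of the Python shape function added below: both read [N, C] from the input shape, take [H, W] from output_size, and assemble the result list.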
@@ -1226,8 +1226,8 @@ def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
 def aten〇frobenius_norm〇dim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
-def aten〇upsample_nearest2d〇vec(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
-    return upstream_shape_functions.upsample_nearest2d(input, output_size, scale_factors)
+def aten〇upsample_nearest2d(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
+    return [self[0], self[1], output_size[0], output_size[1]]
 
 # ==============================================================================
 # Shape library generator main().
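A quick standalone check of the new shape rule (renamed here, since the 〇 character in the generator file encodes a "."; the sample sizes mirror the UpSampleNearest2d e2e test below):

from typing import List, Optional

def upsample_nearest2d_shape(self: List[int], output_size: List[int],
                             scales_h: Optional[float] = None,
                             scales_w: Optional[float] = None) -> List[int]:
    # Batch and channel dims pass through; spatial dims come from output_size.
    return [self[0], self[1], output_size[0], output_size[1]]

assert upsample_nearest2d_shape([1, 1, 6, 12], [18, 48]) == [1, 1, 18, 48]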
@@ -528,7 +528,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
     emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
     emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
-    emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
+    emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
 
     # Dict ops.
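This registry entry is what drives the ODS generator: the emitted signature string is turned into the Torch_AtenUpsampleNearest2dOp TableGen definition shown earlier, including the 4-operand/1-result arity passed to parseDefaultTorchOp and printDefaultTorchOp.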
@@ -714,6 +714,27 @@ class Conv_Transpose3dModule(torch.nn.Module):
 def Conv_Transpose3dModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2))
 
+class UpSampleNearest2d(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d(input,
+                                                 output_size=[18, 48],
+                                                 scales_h=3.0,
+                                                 scales_w=4.0)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2d())
+def UpSampleNearest2d_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
+
 class UpSampleNearest2dSameSize(torch.nn.Module):
 
     def __init__(self):
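Note that the new static test keeps its operands mutually consistent: with a 6x12 input, scales_h=3.0 and scales_w=4.0 give exactly output_size=[18, 48] (6 * 3 = 18, 12 * 4 = 48), so the lowering's scale-driven and size-driven paths agree on the [1, 1, 18, 48] result shape.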
@@ -727,7 +748,8 @@ class UpSampleNearest2dSameSize(torch.nn.Module):
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
                                                output_size=[11, 11],
-                                               scale_factors=None)
+                                               scales_h=None,
+                                               scales_w=None)
 
 
 @register_test_case(module_factory=lambda: UpSampleNearest2dSameSize())
@@ -745,7 +767,8 @@ class UpSampleNearest2dDiffSize(torch.nn.Module):
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
                                                output_size=[8, 11],
-                                               scale_factors=None)
+                                               scales_h=None,
+                                               scales_w=None)
 
 
 @register_test_case(module_factory=lambda: UpSampleNearest2dDiffSize())
@@ -762,8 +785,9 @@ class UpSampleNearest2dDiffFactor(torch.nn.Module):
     @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
-                                               output_size=None,
-                                               scale_factors=[2.3, 4.7])
+                                               output_size=[6, 10],
+                                               scales_h=2.3,
+                                               scales_w=4.7)
 
 
 @register_test_case(module_factory=lambda: UpSampleNearest2dDiffFactor())
@@ -783,8 +807,9 @@ class UpSampleNearest2dSameFactor(torch.nn.Module):
     ])
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
-                                               output_size=None,
-                                               scale_factors=[2.0, 2.0])
+                                               output_size=[8, 8],
+                                               scales_h=2.0,
+                                               scales_w=2.0)
 
 
 @register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor())