mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for Onnx.Upsample Op (#3371)

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

parent 09f502667b
commit 3c3fbe4680
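In brief: the patch hoists the scales/sizes extraction logic out of the existing Resize lowering into a shared getValueList helper, then reuses it to lower onnx.Upsample (opset 9) to torch.aten.__interpolate.size_list_scale_list. A minimal before/after sketch of the op-level mapping, using the 2x2 -> 4x6 shapes from the new lit tests (illustrative IR, not copied from the patch; %scale_list is built as shown after the first hunk below):

  // Before: opset-9 Upsample with an explicit scales tensor.
  %0 = torch.operator "onnx.Upsample"(%input, %scales) {torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>

  // After: the spatial entries of %scales become a !torch.list<float>
  // and the op is replaced by an aten interpolate call.
  %none = torch.constant.none
  %false = torch.constant.bool false
  %mode = torch.constant.str "nearest"
  %1 = torch.aten.__interpolate.size_list_scale_list %input, %none, %scale_list, %mode, %none, %none, %false : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32>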
@@ -152,6 +152,55 @@ LogicalResult reducedSumImpl(OpBinder binder,
   }
   return success();
 }
+
+Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter,
+                   Value operand) {
+  SmallVector<Value> itemList;
+  auto sizes = dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
+  Torch::BaseTensorType operandType =
+      cast<Torch::BaseTensorType>(operand.getType());
+
+  SmallVector<int64_t> selectSizes;
+  selectSizes.push_back(1);
+  Type selectResultType = operandType.getWithSizesAndDtype(
+      llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
+
+  auto extract = [&rewriter, &binder](Value x, Value v) {
+    auto xTy = cast<Torch::ValueTensorType>(x.getType());
+    Type extractTy = rewriter.getType<Torch::FloatType>();
+    if (isa<IntegerType>(xTy.getDtype()))
+      extractTy = rewriter.getType<Torch::IntType>();
+
+    return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy, v);
+  };
+
+  Value zero = rewriter.create<Torch::ConstantIntOp>(
+      binder.getLoc(), rewriter.getType<Torch::IntType>(),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+  MLIRContext *context = binder.op->getContext();
+  for (int i = 2; i < sizes[0]; i++) {
+    Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+        binder.getLoc(), rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+    Value ext = rewriter.create<Torch::AtenSelectIntOp>(
+        binder.getLoc(), selectResultType, operand, zero, selectIndex);
+    Value item = extract(operand, ext);
+    itemList.push_back(item);
+  }
+  auto xTy = cast<Torch::ValueTensorType>(operand.getType());
+  Value ValueList;
+  if (isa<IntegerType>(xTy.getDtype())) {
+    ValueList = rewriter.create<Torch::PrimListConstructOp>(
+        binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)),
+        itemList);
+  } else {
+    ValueList = rewriter.create<Torch::PrimListConstructOp>(
+        binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)),
+        itemList);
+  }
+  return ValueList;
+}
 } // namespace
 
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
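For context, getValueList unstacks a rank-1 tensor into a torch list: it selects each element from index 2 onward (skipping the first two entries, which for Upsample/Resize scales are the batch and channel factors), extracts it as a scalar with aten.item, and packs the scalars into a !torch.list<int> or !torch.list<float> depending on the dtype. A sketch of the IR it emits for a size-4 f32 scales tensor (value names are illustrative):

  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %sel2 = torch.aten.select.int %scales, %int0, %int2 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  %item2 = torch.aten.item %sel2 : !torch.vtensor<[1],f32> -> !torch.float
  %int3 = torch.constant.int 3
  %sel3 = torch.aten.select.int %scales, %int0, %int3 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  %item3 = torch.aten.item %sel3 : !torch.vtensor<[1],f32> -> !torch.float
  %scale_list = torch.prim.ListConstruct %item2, %item3 : (!torch.float, !torch.float) -> !torch.list<float>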
@@ -2830,62 +2879,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                              .getSizes()
                              .size();
 
-        Value zero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-
         Value cstFalse =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
         Value cstTrue =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
         Value modeStrValue;
 
-        auto extract = [&rewriter, &binder](Value x, Value v) {
-          auto xTy = cast<Torch::ValueTensorType>(x.getType());
-          Type extractTy = rewriter.getType<Torch::FloatType>();
-          if (isa<IntegerType>(xTy.getDtype()))
-            extractTy = rewriter.getType<Torch::IntType>();
-
-          return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
-                                                    v);
-        };
-
-        auto getValueList = [&](Value operand) {
-          SmallVector<Value> itemList;
-          auto sizes =
-              dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
-          Torch::BaseTensorType operandType =
-              cast<Torch::BaseTensorType>(operand.getType());
-
-          SmallVector<int64_t> selectSizes;
-          selectSizes.push_back(1);
-          Type selectResultType = operandType.getWithSizesAndDtype(
-              llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
-
-          MLIRContext *context = binder.op->getContext();
-          for (int i = 2; i < sizes[0]; i++) {
-            Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-            Value ext = rewriter.create<Torch::AtenSelectIntOp>(
-                binder.getLoc(), selectResultType, operand, zero, selectIndex);
-            Value item = extract(operand, ext);
-            itemList.push_back(item);
-          }
-          auto xTy = cast<Torch::ValueTensorType>(operand.getType());
-          Value ValueList;
-          if (isa<IntegerType>(xTy.getDtype())) {
-            ValueList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(Torch::IntType::get(context)), itemList);
-          } else {
-            ValueList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(Torch::FloatType::get(context)), itemList);
-          }
-          return ValueList;
-        };
-
         Value scalesValueList = noneVal;
         Value sizesValueList = noneVal;
         Value alignCorners =
@@ -2934,12 +2933,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         }
         if (operands.size() < 4) {
           Value scaleOperand = operands[2];
-          scalesValueList = getValueList(scaleOperand);
+          scalesValueList = getValueList(binder, rewriter, scaleOperand);
           sizesValueList = noneVal;
         } else {
           Value sizeOperand = operands[3];
           scalesValueList = noneVal;
-          sizesValueList = getValueList(sizeOperand);
+          sizesValueList = getValueList(binder, rewriter, sizeOperand);
         }
         if (isa<Torch::NoneType>(scalesValueList.getType()) &&
             isa<Torch::NoneType>(sizesValueList.getType())) {
@@ -3258,4 +3257,47 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         rewriter.replaceOp(binder.op, inputSequence);
         return success();
       });
+  patterns.onOp(
+      "Upsample", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        std::string mode;
+        Value input, scales;
+        if (binder.tensorOperands(input, scales) ||
+            binder.customOpNameStringAttr(mode, "mode", "nearest") ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        if (mode != "nearest" && mode != "linear")
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported interpolation mode other than nearest, "
+                         "linear");
+
+        int64_t resultRank = resultType.getSizes().size();
+        if (resultRank > 5)
+          return rewriter.notifyMatchFailure(
+              binder.op, "supports upto 3d upsampling only");
+
+        Value scalesValueList = getValueList(binder, rewriter, scales);
+        if (mode == "linear") {
+          if (resultRank == 4)
+            mode = "bilinear";
+          if (resultRank == 5)
+            mode = "trilinear";
+        }
+        Value modeStrValue =
+            rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), mode);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(false));
+
+        rewriter
+            .replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
+                binder.op, resultType, input, /*size=*/cstNone, scalesValueList,
+                modeStrValue,
+                /* AnyTorchOptionalBoolType:$align_corners */ cstNone,
+                /* AnyTorchOptionalBoolType:$recompute_scale_factor */ cstNone,
+                /*Torch_BoolType:$antialias*/ cstFalse);
+        return success();
+      });
 }
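Note the guard clauses in the new pattern: modes other than "nearest" and "linear" produce a match failure, as does a result rank above 5 (i.e., beyond 3-D upsampling); a "linear" mode is renamed to "bilinear" or "trilinear" based on rank. A hypothetical op the pattern would leave unlowered (a sketch, not part of the test suite below):

  // "cubic" is rejected by the mode check, so the pattern bails out via
  // notifyMatchFailure and the op stays in its onnx wrapper form.
  %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "cubic"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>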
@@ -2541,3 +2541,45 @@ func.func @test_sequence_empty() -> !torch.list<vtensor<[],f32>> attributes {tor
   %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list<vtensor<[],f32>>
   return %0 : !torch.list<vtensor<[],f32>>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_upsample_nearest
+func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+  // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[INT3:.*]] = torch.constant.int 3
+  // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+  // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
+  // CHECK: %[[MODE:.*]] = torch.constant.str "nearest"
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32>
+  // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32>
+  %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>
+  return %0 : !torch.vtensor<[1,1,4,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_upsample_bilinear
+func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+  // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[INT3:.*]] = torch.constant.int 3
+  // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+  // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
+  // CHECK: %[[MODE:.*]] = torch.constant.str "bilinear"
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32>
+  // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32>
+  %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>
+  return %0 : !torch.vtensor<[1,1,4,6],f32>
+}
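As with the other cases in this test file, these lit tests are driven through torch-mlir-opt and FileCheck by the RUN line at the top of the file; the invocation below is the usual pattern for the TorchOnnxToTorch tests and is shown here as an assumption, not copied from the patch:

  // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s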