[MLIR][ONNX] Add OnnxToTorch support for Mean, IsInf, IsNaN, PRelu op (#2801)

This commit adds the OnnxToTorch support for Mean, IsInf, IsNaN, and
PRelu ops. All high priority ops were taken so went with these. The non
trivial ones are Mean and IsInf which might require extra review

---------

Co-authored-by: MaheshRavishankar <mravisha@amd.com>
pull/2903/head
saienduri 2024-02-12 23:08:21 -08:00 committed by GitHub
parent b6f4ca512e
commit 9b967f6b5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 157 additions and 0 deletions

View File

@ -1006,4 +1006,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, tensor, /*memory_format=*/noneVal);
return success();
});
patterns.onOp(
"Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
if (binder.op->getNumOperands() == 1) {
Torch::ValueTensorType resultType;
Value x;
if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
return failure();
rewriter.replaceOp(binder.op, x);
return success();
}
Torch::ValueTensorType resultType;
SmallVector<Value> valList;
int64_t numOperands = binder.op->getNumOperands();
Value numOperandsConstant = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands));
if (binder.tensorOperands(valList, numOperands) ||
binder.tensorResultType(resultType))
return failure();
Value constOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
// Short circuit to binary add
Value curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, valList[0], valList[1], constOne);
if (numOperands == 2) {
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, curr, numOperandsConstant);
return success();
}
// When binder.op->getNumOperands() > 2
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
binder.op->getContext());
for (int i = 2; i < numOperands; i++) {
if (i == numOperands - 1) {
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, curr, valList[i], constOne);
} else {
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), baseType, curr, valList[i], constOne);
}
}
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, curr, numOperandsConstant);
return success();
});
patterns.onOp(
"IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensor;
int64_t neg;
int64_t pos;
if (binder.tensorOperand(tensor) ||
binder.s64IntegerAttr(neg, "detect_negative", 1) ||
binder.s64IntegerAttr(pos, "detect_positive", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
if (neg == 0) {
// replace all negative infs with 0
tensor = rewriter.create<Torch::AtenReluOp>(
binder.getLoc(),
dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
}
if (pos == 0) {
// first use neg op to flip positive inf to negative inf. Then relu to
// replace all positive infs with 0.
Value flip = rewriter.create<Torch::AtenNegOp>(
binder.getLoc(),
dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
tensor = rewriter.create<Torch::AtenReluOp>(
binder.getLoc(), dyn_cast<Torch::ValueTensorType>(flip.getType()),
flip);
}
rewriter.replaceOpWithNewOp<Torch::AtenIsinfOp>(binder.op, resultType,
tensor);
return success();
});
patterns.onOp("IsNaN", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensor;
if (binder.tensorOperand(tensor) ||
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenIsnanOp>(
binder.op, resultType, tensor);
return success();
});
patterns.onOp("PRelu", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensor;
Value slope;
if (binder.tensorOperands(tensor, slope) ||
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenPreluOp>(
binder.op, resultType, tensor, slope);
return success();
});
}

View File

@ -650,3 +650,57 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
%0 = torch.operator "onnx.Identity"(%arg0) : (!torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32>
return %0 : !torch.vtensor<[3,4], f32>
}
// CHECK-LABEL: func.func @test_mean_one_input
func.func @test_mean_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.Mean"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// CHECK-LABEL: func.func @test_mean_two_inputs
func.func @test_mean_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.div.Scalar %0, %int2 : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Mean"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// CHECK-LABEL: func.func @test_isinf_negative
func.func @test_isinf_negative(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.neg %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
// CHECK: torch.aten.relu %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
// CHECK: torch.aten.isinf %1 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
%0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_positive = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
return %0 : !torch.vtensor<[6],i1>
}
// CHECK-LABEL: func.func @test_isinf_positive
func.func @test_isinf_positive(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.relu %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
// CHECK: torch.aten.isinf %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
%0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_negative = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
return %0 : !torch.vtensor<[6],i1>
}
// CHECK-LABEL: func.func @test_isnan
func.func @test_isnan(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.isnan %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
%0 = torch.operator "onnx.IsNaN"(%arg0) : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
return %0 : !torch.vtensor<[6],i1>
}
// CHECK-LABEL: func.func @test_prelu_example
func.func @test_prelu_example(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// CHECK-LABEL: func.func @test_prelu_broadcast
func.func @test_prelu_broadcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}