mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Mean, IsInf, IsNaN, PRelu op (#2801)
This commit adds OnnxToTorch support for the Mean, IsInf, IsNaN, and PRelu ops. All high-priority ops were already taken, so I went with these. The non-trivial ones are Mean and IsInf, which might require extra review.

Co-authored-by: MaheshRavishankar <mravisha@amd.com>
parent b6f4ca512e
commit 9b967f6b5a
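For context on the non-trivial Mean lowering below: with N > 1 operands it sums the inputs with chained torch.aten.add.Tensor ops and then divides once by N. A minimal hand-written sketch of the expected IR for a hypothetical three-operand onnx.Mean over [3],f32 tensors (illustrative only; the tests added in this commit cover just the one- and two-operand cases):

    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    %1 = torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    %2 = torch.aten.div.Scalar %1, %int3 : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>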
@@ -1006,4 +1006,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, tensor, /*memory_format=*/noneVal);
        return success();
      });
  patterns.onOp(
      "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        if (binder.op->getNumOperands() == 1) {
          // A single operand is its own mean; forward it unchanged.
          Torch::ValueTensorType resultType;
          Value x;
          if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
            return failure();
          rewriter.replaceOp(binder.op, x);
          return success();
        }
        Torch::ValueTensorType resultType;
        SmallVector<Value> valList;
        int64_t numOperands = binder.op->getNumOperands();
        Value numOperandsConstant = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands));
        if (binder.tensorOperands(valList, numOperands) ||
            binder.tensorResultType(resultType))
          return failure();
        Value constOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
        // Short-circuit to a single binary add when there are exactly two
        // operands.
        Value curr = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), resultType, valList[0], valList[1], constOne);
        if (numOperands == 2) {
          rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
              binder.op, resultType, curr, numOperandsConstant);
          return success();
        }
        // More than two operands: chain the remaining adds, typing the
        // intermediates with the least static information and only the final
        // add with the result type.
        auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
            binder.op->getContext());
        for (int i = 2; i < numOperands; i++) {
          if (i == numOperands - 1) {
            curr = rewriter.create<Torch::AtenAddTensorOp>(
                binder.getLoc(), resultType, curr, valList[i], constOne);
          } else {
            curr = rewriter.create<Torch::AtenAddTensorOp>(
                binder.getLoc(), baseType, curr, valList[i], constOne);
          }
        }
        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
            binder.op, resultType, curr, numOperandsConstant);
        return success();
      });
  patterns.onOp(
      "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        int64_t neg;
        int64_t pos;
        if (binder.tensorOperand(tensor) ||
            binder.s64IntegerAttr(neg, "detect_negative", 1) ||
            binder.s64IntegerAttr(pos, "detect_positive", 1) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        if (neg == 0) {
          // Replace all negative infs with 0 so they are not reported.
          tensor = rewriter.create<Torch::AtenReluOp>(
              binder.getLoc(),
              dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
        }
        if (pos == 0) {
          // First negate to flip positive infs to negative infs, then relu to
          // replace all (originally positive) infs with 0.
          Value flip = rewriter.create<Torch::AtenNegOp>(
              binder.getLoc(),
              dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
          tensor = rewriter.create<Torch::AtenReluOp>(
              binder.getLoc(), dyn_cast<Torch::ValueTensorType>(flip.getType()),
              flip);
        }
        rewriter.replaceOpWithNewOp<Torch::AtenIsinfOp>(binder.op, resultType,
                                                        tensor);
        return success();
      });
  patterns.onOp("IsNaN", 9,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  if (binder.tensorOperand(tensor) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenIsnanOp>(
                      binder.op, resultType, tensor);
                  return success();
                });
  patterns.onOp("PRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  Value slope;
                  if (binder.tensorOperands(tensor, slope) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenPreluOp>(
                      binder.op, resultType, tensor, slope);
                  return success();
                });
}
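For intuition on the IsInf masking above (illustrative values, not taken from the commit): with detect_positive = 0 and input x = [-inf, 1.0, +inf], the rewrite produces

    neg(x)       = [+inf, -1.0, -inf]
    relu(neg(x)) = [+inf,  0.0,  0.0]
    isinf(...)   = [true, false, false]

so only the original -inf is reported. Symmetrically, detect_negative = 0 applies a single relu first, zeroing the negative infs before isinf.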
@@ -650,3 +650,57 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
  %0 = torch.operator "onnx.Identity"(%arg0) : (!torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32>
  return %0 : !torch.vtensor<[3,4], f32>
}

// CHECK-LABEL: func.func @test_mean_one_input
func.func @test_mean_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.Mean"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
  return %0 : !torch.vtensor<[3],f32>
}

// CHECK-LABEL: func.func @test_mean_two_inputs
func.func @test_mean_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT2:.*]] = torch.constant.int 2
  // CHECK: %[[INT1:.*]] = torch.constant.int 1
  // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
  // CHECK: torch.aten.div.Scalar %0, %int2 : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
  %0 = torch.operator "onnx.Mean"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
  return %0 : !torch.vtensor<[3],f32>
}

// CHECK-LABEL: func.func @test_isinf_negative
func.func @test_isinf_negative(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
  // CHECK: torch.aten.relu %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
  // CHECK: torch.aten.isinf %1 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
  %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_positive = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
  return %0 : !torch.vtensor<[6],i1>
}

// CHECK-LABEL: func.func @test_isinf_positive
func.func @test_isinf_positive(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32>
  // CHECK: torch.aten.isinf %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
  %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_negative = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
  return %0 : !torch.vtensor<[6],i1>
}

// CHECK-LABEL: func.func @test_isnan
func.func @test_isnan(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.isnan %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1>
  %0 = torch.operator "onnx.IsNaN"(%arg0) : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1>
  return %0 : !torch.vtensor<[6],i1>
}

// CHECK-LABEL: func.func @test_prelu_example
func.func @test_prelu_example(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: func.func @test_prelu_broadcast
func.func @test_prelu_broadcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}