mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Dropout and Elu op
Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
parent 07d0645f64
commit 35e8f86792
@@ -904,6 +904,62 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp(
      "Dropout", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Location loc = binder.getLoc();
        Torch::ValueTensorType resultType;
        int64_t numOperands = binder.op->getNumOperands();
        SmallVector<Value> operands;
        int64_t seed;
        if (binder.tensorOperands(operands, numOperands) ||
            binder.s64IntegerAttr(seed, "seed", 0) ||
            binder.tensorResultTypeAtIndex(resultType, 0))
          return failure();

        // Only the default global seed value (0) is supported.
        if (seed != 0) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "expected seed value to be 0");
        }

        Value ratio, trainingMode;
        if (numOperands == 3) {
          ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
          Value trainingModeScalar =
              rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
          Value cstOne = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(1));
          trainingMode = rewriter.create<Torch::AtenEqIntOp>(
              loc, trainingModeScalar, cstOne);
        } else if (numOperands == 2) {
          ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
          trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        } else {
          ratio = rewriter.create<Torch::ConstantFloatOp>(
              loc, rewriter.getF64FloatAttr(0.5));
          trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        }

        Value dropout = rewriter.create<Torch::AtenDropoutOp>(
            loc, resultType, /*input=*/operands[0], ratio, trainingMode);

        if (binder.op->getNumResults() == 1) {
          rewriter.replaceOp(binder.op, dropout);
          return success();
        }
        Torch::ValueTensorType maskType;
        if (binder.tensorResultTypeAtIndex(maskType, 1))
          return failure();
        Value dtype = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(
                     (int64_t)torch_upstream::ScalarType::Bool));
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        Value mask = rewriter.create<Torch::AtenOnesLikeOp>(
            loc, maskType, operands[0], dtype, /*layout=*/none,
            /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
        rewriter.replaceOp(binder.op, {dropout, mask});
        return success();
      });
  patterns.onOp("Equal", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
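For context on the pattern above: ONNX Dropout returns the dropped-out data and, optionally, a boolean mask of the elements that were kept. In training mode the kept values are rescaled by 1/(1 - ratio); in inference mode the op is the identity, every element survives, and the mask is all ones, which is why the lowering can materialize the mask with a plain torch.aten.ones_like for the inference-mode cases exercised by the tests below. A sketch of the ONNX semantics, with x the input, p the ratio, and m the mask:

\text{training:}\quad y_i = \frac{m_i \, x_i}{1 - p}, \qquad m_i \sim \operatorname{Bernoulli}(1 - p)

\text{inference:}\quad y_i = x_i, \qquad m_i = 1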
@@ -916,6 +972,25 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("Elu", 6,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Location loc = binder.getLoc();
                  Torch::ValueTensorType resultType;
                  Value input;
                  float alpha;
                  if (binder.tensorOperand(input) ||
                      binder.f32FloatAttr(alpha, "alpha") ||
                      binder.tensorResultType(resultType))
                    return failure();
                  Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
                      loc, rewriter.getF64FloatAttr(alpha));
                  Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
                      loc, rewriter.getF64FloatAttr(1.0));
                  rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
                      binder.op, resultType, input, cstAlpha, /*scale=*/cstOne,
                      /*input_scale=*/cstOne);
                  return success();
                });
  patterns.onOp("Erf", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
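A note on why the Elu pattern passes cstOne twice: torch.aten.elu generalizes ELU with additional scale and input_scale parameters; up to those parameters it computes

\mathrm{elu}(x) = \mathrm{scale}\cdot\bigl(\max(0, x) + \min\bigl(0,\ \alpha\,(e^{\mathrm{input\_scale}\cdot x} - 1)\bigr)\bigr)

so fixing scale = input_scale = 1 reduces it to the ONNX Elu definition, f(x) = x for x >= 0 and f(x) = alpha * (e^x - 1) for x < 0.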
@@ -740,3 +740,61 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a
  %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32>
  return %0 : !torch.vtensor<[4,2,2],f32>
}

// CHECK-LABEL: @test_dropout
func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32>
  %0 = torch.operator "onnx.Dropout"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
  return %0 : !torch.vtensor<[3],f32>
}

// CHECK-LABEL: @test_dropout_default
func.func @test_dropout_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_dropout_default_mask
func.func @test_dropout_default_mask(%arg0: !torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32>
  // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1>
  %0:2 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>)
  return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>
}

// CHECK-LABEL: @test_dropout_default_mask_ratio
func.func @test_dropout_default_mask_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32>
  // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1>
  %0:2 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>)
  return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>
}

// CHECK-LABEL: @test_dropout_default_ratio
func.func @test_dropout_default_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_training_dropout_zero_ratio
func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.dropout %arg0, %0, %2 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Dropout"(%arg0, %arg1, %arg2) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_elu_default
func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_elu_example
func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.elu %arg0, %float2.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3],f32>
  %0 = torch.operator "onnx.Elu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
  return %0 : !torch.vtensor<[3],f32>
}
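Assuming this test file follows the existing TorchOnnxToTorch lit conventions (the RUN line sits outside this diff, so the exact flags here are an assumption, not part of the change), the cases above would be driven by something like:

// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s

Each CHECK line asserts that the ONNX operator is rewritten into the expected torch dialect op with the constants produced by the patterns above.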