mirror of https://github.com/llvm/torch-mlir

[ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op (#3221)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

parent b2185195e8
commit b1e2241479

@@ -847,15 +847,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
   patterns.onOp(
       "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
         Torch::ValueTensorType resultType;
         float alpha, gamma;
         Value operand;
+        // Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
+        // alpha and gamma values.
         if (binder.tensorOperand(operand) ||
-            binder.f32FloatAttr(alpha, "alpha") ||
-            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
+            binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
             binder.tensorResultType(resultType))
           return failure();
 
+        Torch::ValueTensorType inputType =
+            operand.getType().cast<Torch::ValueTensorType>();
+
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
 
@@ -864,12 +870,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
 
-        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
+        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
 
-        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
-            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
+            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
+            cstNone, cstNone);
+        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, exp, vAlpha);
+        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
+            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
+        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
+        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, operand, vScale);
+        Type compareType = inputType.getWithSizesAndDtype(
+            inputType.getOptionalSizes(), rewriter.getI1Type());
+        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
+            binder.getLoc(), compareType, operand, zeroTensor);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
+            binder.op, resultType, xLessThanZero, neg, pos);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
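As a sanity check on the arithmetic, here is a minimal NumPy sketch, not part of the commit, that evaluates the closed-form Selu definition from the comment above against the op-by-op sequence the new lowering emits (zeros_like, exp, mul.Scalar, sub.Scalar, mul.Scalar, lt.Tensor, where.self). The sample input is arbitrary, and the alpha = 2.0 / gamma = 3.0 values mirror the lit test further down.

import numpy as np

def selu_reference(x, alpha, gamma):
    # y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
    return np.where(x <= 0, gamma * (alpha * np.exp(x) - alpha), gamma * x)

def selu_decomposed(x, alpha, gamma):
    # Mirrors the op sequence emitted by the new lowering.
    zeros = np.zeros_like(x)                          # aten.zeros_like
    exp = np.exp(x)                                   # aten.exp
    exp_mul_alpha = exp * alpha                       # aten.mul.Scalar
    exp_mul_alpha_sub_alpha = exp_mul_alpha - alpha   # aten.sub.Scalar (step scale cstOne = 1.0)
    neg = exp_mul_alpha_sub_alpha * gamma             # aten.mul.Scalar
    pos = x * gamma                                   # aten.mul.Scalar
    x_lt_zero = x < zeros                             # aten.lt.Tensor
    return np.where(x_lt_zero, neg, pos)              # aten.where.self

x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
assert np.allclose(selu_reference(x, 2.0, 3.0), selu_decomposed(x, 2.0, 3.0))

At x = 0 the two formulations agree even though the lowering compares with a strict less-than: both branches evaluate to 0 there.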
@@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;
 
   Type inputDtype = inputTensorType.getOptionalDtype();
-  if (!inputDtype || !inputDtype.isInteger(64))
+  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
     return nullptr;
 
   std::optional<unsigned> inputRank = getTensorRank(input);
@@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;
 
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
-    auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
-                   .getSplatValue<int64_t>();
-    return rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(val));
+    if (inputDtype.isInteger(64)) {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<int64_t>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    } else {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<bool>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    }
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
     return primNumToTensorScalarOp.getA();
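For illustration only, a small Python sketch of the dtype dispatch this change introduces in getScalarIntValue: splat i64 constants fold to their integer value, and splat i1 constants are now read as bool and folded to 0 or 1. The helper name fold_splat_to_int and the string dtype tags are hypothetical, not torch-mlir API.

def fold_splat_to_int(dtype: str, splat_value):
    """Hypothetical mirror of the canonicalizer's dtype dispatch."""
    if dtype not in ("i64", "i1"):
        return None  # only 64-bit integer and boolean splats are folded
    if dtype == "i64":
        return int(splat_value)       # i64 splat: use the value directly
    return 1 if bool(splat_value) else 0  # i1 splat: promote bool to 0/1

assert fold_splat_to_int("i64", 42) == 42
assert fold_splat_to_int("i1", True) == 1
assert fold_splat_to_int("f32", 1.0) is None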
@@ -2124,7 +2124,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",
-    "ElementwiseSeluModule_basic",
     "FlipModuleStaticShape_basic",
     "FlipNegativeIndexModule_basic",
     "HardsigmoidModule_basic",
@@ -2637,8 +2636,6 @@ ONNX_XFAIL_SET = {
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
     "CumsumInputDtypeInt32Module_basic",
-    "DropoutTrainModule_basic",
-    "DropoutTrainStaticShapeModule_basic",
     "ElementwiseAcosIntModule_basic",
     "ElementwiseAsinIntModule_basic",
     "ElementwiseAtanTensorIntModule_basic",
@@ -582,10 +582,18 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
 
 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
-  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
-  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
-  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
-  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
+  // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00
+  // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00
+  // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
+  // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }