mirror of https://github.com/llvm/torch-mlir
[ONNX] LogSoftmax to Torch (#3024)
This PR adds support for onnx.LogSoftmax, both for old versions (<13, with axis >= 0) and new versions (13).
pull/3049/head
parent 50635dd509
commit 6aa481c204
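For context on the two opsets being handled: in opset 13, onnx.LogSoftmax is a plain per-axis log-softmax with default axis -1, while in opsets < 13 the input is coerced to 2-D by flattening every dim from axis onward (default axis 1), log-softmax is applied over that flattened dim, and the result is reshaped back. A NumPy sketch of both behaviors (helper names here are illustrative, not part of the patch):

    import numpy as np

    def log_softmax(x, axis):
        # Numerically stable log-softmax along one axis.
        m = x.max(axis=axis, keepdims=True)
        return x - m - np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

    def onnx_logsoftmax_13(x, axis=-1):
        # Opset 13: log-softmax directly along `axis`.
        return log_softmax(x, axis)

    def onnx_logsoftmax_old(x, axis=1):
        # Opset <13 (axis >= 0): flatten dims [axis:] into one, apply, restore.
        flat = x.reshape(*x.shape[:axis], -1)
        return log_softmax(flat, -1).reshape(x.shape)

    x = np.random.randn(3, 4, 5).astype(np.float32)
    # The two semantics agree only when `axis` is the last dimension:
    print(np.allclose(onnx_logsoftmax_13(x, 2), onnx_logsoftmax_old(x, 2)))  # True
    print(np.allclose(onnx_logsoftmax_13(x, 1), onnx_logsoftmax_old(x, 1)))  # False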
@@ -195,6 +195,100 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, operand);
        return success();
      });
  patterns.onOp(
      "LogSoftmax", 13,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value input;
        Torch::ValueTensorType resultType;
        if (binder.tensorOperand(input) || binder.tensorResultType(resultType))
          return failure();
        int64_t axis;
        if (binder.s64IntegerAttr(axis, "axis", -1))
          return rewriter.notifyMatchFailure(binder.op, "axis bind failure");
        Value axisConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        rewriter.replaceOpWithNewOp<Torch::AtenLogSoftmaxIntOp>(
            binder.op, resultType, input, axisConst, none);
        return success();
      });
  patterns.onOp(
      "LogSoftmax", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value input;
        Torch::ValueTensorType resultType;
        if (binder.tensorOperand(input) || binder.tensorResultType(resultType))
          return failure();
        int64_t axis;
        if (binder.s64IntegerAttr(axis, "axis", 1))
          return rewriter.notifyMatchFailure(binder.op, "axis bind failure");
        std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unsupported: unranked tensor");
        int64_t rank = *maybeRank;
        // If a negative axis is provided, flip it to the equivalent
        // non-negative axis.
        if (axis < 0) {
          axis = rank + axis;
        }
        // The input type and sizes are needed to flatten/unflatten later.
        auto inputTy = input.getType().cast<Torch::ValueTensorType>();
        if (!inputTy || !inputTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "failed to get input type or sizes");

        Value axisConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstEnd = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1));

        // The old version of LogSoftmax flattens post-axis dims, performs
        // LogSoftmax on the flattened dim, then unflattens back to the
        // original shape.

        // This section gathers the size information necessary for
        // flattening/unflattening.
        llvm::ArrayRef<int64_t> allDims(inputTy.getSizes());
        llvm::ArrayRef<int64_t> rightDims(allDims.begin() + axis,
                                          allDims.end());
        llvm::SmallVector<int64_t> leftDims(allDims.begin(),
                                            allDims.begin() + axis);
        int64_t prodRightSizes = 1;
        llvm::SmallVector<Value> rightDimConsts;
        for (int64_t n : rightDims) {
          rightDimConsts.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(n)));
          // Any dynamic dim makes the flattened size dynamic; keep
          // collecting the per-dim constants so the unflatten list still
          // covers every post-axis dim.
          if (n == Torch::kUnknownSize ||
              prodRightSizes == Torch::kUnknownSize)
            prodRightSizes = Torch::kUnknownSize;
          else
            prodRightSizes *= n;
        }
        leftDims.push_back(prodRightSizes);
        // The following list will be used to unflatten the right side.
        Value rightDimsPrimList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            rightDimConsts);
        auto flatRightTy = rewriter.getType<Torch::ValueTensorType>(
            leftDims, inputTy.getOptionalDtype());
        // Flatten the input.
        Value inputFlatRight = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
            binder.getLoc(), flatRightTy, input, axisConst, cstEnd);
        // Compute log-softmax over the flattened dim.
        Value outputFlatRight = rewriter.create<Torch::AtenLogSoftmaxIntOp>(
            binder.getLoc(), flatRightTy, inputFlatRight, axisConst, none);
        // Unflatten back to the original shape.
        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
            binder.op, resultType, outputFlatRight, axisConst,
            rightDimsPrimList);
        return success();
      });
patterns.onOp("MatMul", 13,
|
patterns.onOp("MatMul", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
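The shape bookkeeping in the opset-1 LogSoftmax pattern above is easy to check by hand. A Python sketch of the same size computation (hypothetical helper; -1 stands in for Torch::kUnknownSize, matching the constant the lowering emits for a dynamic dim):

    def old_logsoftmax_shapes(in_shape, axis):
        # Mirrors the flatten/unflatten sizes of the opset-1 lowering:
        # returns (flattened shape, unflatten list); -1 marks a dynamic dim.
        left, right = list(in_shape[:axis]), list(in_shape[axis:])
        prod = 1
        for n in right:
            prod = -1 if (n == -1 or prod == -1) else prod * n
        return left + [prod], right

    # Static case from the tests: [3,4,5] with axis = 1
    print(old_logsoftmax_shapes([3, 4, 5], 1))   # ([3, 20], [4, 5])
    # Dynamic case from the tests: [3,4,?] with axis = 1
    print(old_logsoftmax_shapes([3, 4, -1], 1))  # ([3, -1], [4, -1])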
@@ -1906,11 +1906,6 @@ ONNX_XFAIL_SET = {
    "HardswishRandomModule_basic",
    "MobilenetV3Module_basic",

    # Failure - onnx_lowering: onnx.LogSoftmax
    "LogSoftmaxIntModule_basic",
    "_LogSoftmaxModuleStable_basic",
    "_LogSoftmaxModule_basic",

    # Failure - onnx_lowering: onnx.MaxPool
    "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
    "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
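This hunk removes the onnx.LogSoftmax entries from ONNX_XFAIL_SET (the header's -1906,11 +1906,6 shows five lines deleted), since the op now lowers and its e2e tests pass. For context, a minimal sketch of how a harness typically consumes such a set (illustrative only, not torch-mlir's actual runner):

    def check_against_xfail(results, xfail_set):
        # results maps test name -> bool (True = passed). A run is clean
        # when the observed failures and the xfail set coincide exactly.
        unexpected = [n for n, ok in results.items() if not ok and n not in xfail_set]
        stale = [n for n, ok in results.items() if ok and n in xfail_set]
        return unexpected, stale

    # After this change the LogSoftmax tests pass, so keeping them listed
    # would have flagged them as stale:
    print(check_against_xfail({"_LogSoftmaxModule_basic": True}, set()))  # ([], [])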
@@ -748,6 +748,66 @@ func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch

// -----

// CHECK-LABEL: func.func @test_log_softmax_default_axis
func.func @test_log_softmax_default_axis(%arg0: !torch.vtensor<[1,3],f32>) -> !torch.vtensor<[1,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[CIM1:.*]] = torch.constant.int -1
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[LSM:.*]] = torch.aten.log_softmax.int %arg0, %[[CIM1]], %[[NONE]] : !torch.vtensor<[1,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[1,3],f32>
  // CHECK: return %[[LSM]] : !torch.vtensor<[1,3],f32>
  %0 = torch.operator "onnx.LogSoftmax"(%arg0) : (!torch.vtensor<[1,3],f32>) -> !torch.vtensor<[1,3],f32>
  return %0 : !torch.vtensor<[1,3],f32>
}

// -----

// CHECK-LABEL: func.func @test_log_softmax_axis_2
func.func @test_log_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[CI2:.*]] = torch.constant.int 2
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[LSM:.*]] = torch.aten.log_softmax.int %arg0, %[[CI2]], %[[NONE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32>
  // CHECK: return %[[LSM]] : !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.LogSoftmax"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_logsoftmax_old_axis_1_dynamic_dim
func.func @test_logsoftmax_old_axis_1_dynamic_dim(%arg0: !torch.vtensor<[3,4,?],f32>) -> !torch.vtensor<[3,4,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[CI1:.*]] = torch.constant.int 1
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[CI2:.*]] = torch.constant.int 2
  // CHECK: %[[CI4:.*]] = torch.constant.int 4
  // CHECK: %[[CIM1:.*]] = torch.constant.int -1
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CI4]], %[[CIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLAT_IN:.*]] = torch.aten.flatten.using_ints %arg0, %[[CI1]], %[[CI2]] : !torch.vtensor<[3,4,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,?],f32>
  // CHECK: %[[LSM:.*]] = torch.aten.log_softmax.int %[[FLAT_IN]], %[[CI1]], %[[NONE]] : !torch.vtensor<[3,?],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,?],f32>
  // CHECK: %[[UNFLAT:.*]] = torch.aten.unflatten.int %[[LSM]], %[[CI1]], %[[LIST]] : !torch.vtensor<[3,?],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,4,?],f32>
  // CHECK: return %[[UNFLAT]] : !torch.vtensor<[3,4,?],f32>
  %0 = torch.operator "onnx.LogSoftmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,?],f32>) -> !torch.vtensor<[3,4,?],f32>
  return %0 : !torch.vtensor<[3,4,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_logsoftmax_old_axis_1_static
func.func @test_logsoftmax_old_axis_1_static(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[CI1:.*]] = torch.constant.int 1
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[CI2:.*]] = torch.constant.int 2
  // CHECK: %[[CI4:.*]] = torch.constant.int 4
  // CHECK: %[[CI5:.*]] = torch.constant.int 5
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CI4]], %[[CI5]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLAT_IN:.*]] = torch.aten.flatten.using_ints %arg0, %[[CI1]], %[[CI2]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,20],f32>
  // CHECK: %[[LSM:.*]] = torch.aten.log_softmax.int %[[FLAT_IN]], %[[CI1]], %[[NONE]] : !torch.vtensor<[3,20],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,20],f32>
  // CHECK: %[[UNFLAT:.*]] = torch.aten.unflatten.int %[[LSM]], %[[CI1]], %[[LIST]] : !torch.vtensor<[3,20],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
  // CHECK: return %[[UNFLAT]] : !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.LogSoftmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}
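The three ops emitted by the old-version lowering have direct PyTorch equivalents (torch.flatten, F.log_softmax, torch.unflatten), so the static test above can be sanity-checked numerically. A sketch assuming standard torch APIs:

    import torch
    import torch.nn.functional as F

    x = torch.randn(3, 4, 5)
    axis, rank = 1, x.dim()

    # Opset <13 semantics, mirroring the emitted Torch ops:
    flat = torch.flatten(x, axis, rank - 1)   # [3, 20]
    lsm = F.log_softmax(flat, dim=axis)       # log-softmax over the flat dim
    old = torch.unflatten(lsm, axis, (4, 5))  # back to [3, 4, 5]

    # Opset 13 semantics differ whenever more than one dim was flattened:
    new = F.log_softmax(x, dim=axis)
    print(old.shape, torch.allclose(old, new))  # torch.Size([3, 4, 5]) False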

// -----

// CHECK-LABEL: func.func @test_neg
func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>