[MLIR][TORCH] Add OnnxToTorch lowering for Einsum op (#3117)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3121/head
Vivek Khandelwal 2024-04-08 22:38:01 +05:30 committed by GitHub
parent 84c24e5771
commit 1d6e4c3d77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 7 deletions

View File

@ -1951,4 +1951,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
return success();
});
patterns.onOp(
"Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
SmallVector<Value> tensors;
std::string equation;
if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
binder.customOpNameStringAttr(equation, "equation") ||
binder.tensorResultType(resultType))
return failure();
Type listElemType =
tensors[0]
.getType()
.cast<Torch::BaseTensorType>()
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.op->getLoc(), listType, tensors);
Value cstEquation = rewriter.create<Torch::ConstantStrOp>(
binder.getLoc(), rewriter.getType<Torch::StringType>(),
rewriter.getStringAttr(equation));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
rewriter.replaceOpWithNewOp<Torch::AtenEinsumOp>(
binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
return success();
});
}

View File

@ -1977,13 +1977,6 @@ ONNX_XFAIL_SET = {
# Failure - onnx_lowering: onnx.Clip
"NormalizeModule_basic",
# Failure - onnx_lowering: onnx.Einsum
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
# Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",

View File

@ -1743,3 +1743,51 @@ func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.
%0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
return %0 : !torch.vtensor<[2,2,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_einsum_batch_diagonal
func.func @test_einsum_batch_diagonal(%arg0: !torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
// CHECK: %[[EQUATION:.*]] = torch.constant.str "...ii ->...i"
// CHECK: %[[PATH:.*]] = torch.constant.none
// CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3,5],f64>
%0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "...ii ->...i"} : (!torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64>
return %0 : !torch.vtensor<[3,5],f64>
}
// -----
// CHECK-LABEL: func.func @test_einsum_batch_matmul
func.func @test_einsum_batch_matmul(%arg0: !torch.vtensor<[5,2,3],f64>, %arg1: !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.list<vtensor>
// CHECK: %[[EQUATION:.*]] = torch.constant.str "bij, bjk -> bik"
// CHECK: %[[PATH:.*]] = torch.constant.none
// CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[5,2,4],f64>
%0 = torch.operator "onnx.Einsum"(%arg0, %arg1) {torch.onnx.equation = "bij, bjk -> bik"} : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64>
return %0 : !torch.vtensor<[5,2,4],f64>
}
// -----
// CHECK-LABEL: func.func @test_einsum_sum
func.func @test_einsum_sum(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
// CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->i"
// CHECK: %[[PATH:.*]] = torch.constant.none
// CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3],f64>
%0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->i"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64>
return %0 : !torch.vtensor<[3],f64>
}
// -----
// CHECK-LABEL: func.func @test_einsum_transpose
func.func @test_einsum_transpose(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
// CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->ji"
// CHECK: %[[PATH:.*]] = torch.constant.none
// CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[4,3],f64>
%0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->ji"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64>
return %0 : !torch.vtensor<[4,3],f64>
}