mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add OnnxToTorch lowering for Einsum op (#3117)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
branch pull/3121/head
parent 84c24e5771
commit 1d6e4c3d77
@@ -1951,4 +1951,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
        return success();
      });
  patterns.onOp(
      "Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        SmallVector<Value> tensors;
        std::string equation;
        if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
            binder.customOpNameStringAttr(equation, "equation") ||
            binder.tensorResultType(resultType))
          return failure();
        Type listElemType =
            tensors[0]
                .getType()
                .cast<Torch::BaseTensorType>()
                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                      /*optionalDtype=*/nullptr);
        Type listType = Torch::ListType::get(listElemType);
        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
            binder.op->getLoc(), listType, tensors);
        Value cstEquation = rewriter.create<Torch::ConstantStrOp>(
            binder.getLoc(), rewriter.getType<Torch::StringType>(),
            rewriter.getStringAttr(equation));
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        rewriter.replaceOpWithNewOp<Torch::AtenEinsumOp>(
            binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
        return success();
      });
}
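The pattern above maps onnx.Einsum one-to-one onto torch.aten.einsum: the operands are packed into a tensor list, the equation attribute becomes a string constant, and the optional contraction path is left as none. The following minimal Python sketch (not part of the patch; the equation and shapes are illustrative) shows the equivalence the lowering relies on, cross-checking torch.einsum against NumPy's einsum, which uses the same notation as the ONNX spec:

import numpy as np
import torch

# Illustrative equation and shapes; ONNX Einsum and torch.einsum share the
# same subscript notation, which is what makes the direct lowering possible.
equation = "bij,bjk->bik"
a = torch.randn(5, 2, 3, dtype=torch.float64)
b = torch.randn(5, 3, 4, dtype=torch.float64)

# What the lowered IR computes: aten.einsum(equation, [a, b], path=None).
torch_out = torch.einsum(equation, a, b)

# Reference result via NumPy's einsum.
numpy_out = np.einsum(equation, a.numpy(), b.numpy())

assert torch_out.shape == (5, 2, 4)
assert np.allclose(torch_out.numpy(), numpy_out)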
@@ -1977,13 +1977,6 @@ ONNX_XFAIL_SET = {
    # Failure - onnx_lowering: onnx.Clip
    "NormalizeModule_basic",

    # Failure - onnx_lowering: onnx.Einsum
    "EinsumStaticContractRhsModule_basic",
    "EinsumStaticFourDimensionModule_basic",
    "EinsumStaticModule_basic",
    "EinsumStaticWithEllipsisSlicingModule_basic",
    "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",

    # Failure - onnx_lowering: onnx.MaxPool
    "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
    "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
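This hunk drops the onnx.Einsum block from ONNX_XFAIL_SET: the listed end-to-end tests exercise torch.einsum through the ONNX import path and are no longer expected to fail once the lowering above is in place. As a rough, hypothetical sketch of the kind of module such a test wraps (the class name, equation, and shapes below are illustrative, not the actual test definitions from the e2e suite):

import torch

class EinsumStaticExample(torch.nn.Module):
    """Illustrative stand-in for tests like EinsumStaticModule_basic."""

    def forward(self, x, y):
        # Static-shaped contraction; the e2e suite compares the compiled
        # result against eager PyTorch.
        return torch.einsum("ijk,km->ijm", x, y)

module = EinsumStaticExample()
out = module(torch.randn(2, 3, 4), torch.randn(4, 5))
assert out.shape == (2, 3, 5)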
@@ -1743,3 +1743,51 @@ func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.
  %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
  return %0 : !torch.vtensor<[2,2,4],f32>
}

// -----

// CHECK-LABEL: func.func @test_einsum_batch_diagonal
func.func @test_einsum_batch_diagonal(%arg0: !torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "...ii ->...i"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3,5],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "...ii ->...i"} : (!torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64>
  return %0 : !torch.vtensor<[3,5],f64>
}

// -----

// CHECK-LABEL: func.func @test_einsum_batch_matmul
func.func @test_einsum_batch_matmul(%arg0: !torch.vtensor<[5,2,3],f64>, %arg1: !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "bij, bjk -> bik"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[5,2,4],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0, %arg1) {torch.onnx.equation = "bij, bjk -> bik"} : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.vtensor<[5,2,4],f64>
  return %0 : !torch.vtensor<[5,2,4],f64>
}

// -----

// CHECK-LABEL: func.func @test_einsum_sum
func.func @test_einsum_sum(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->i"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->i"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3],f64>
  return %0 : !torch.vtensor<[3],f64>
}

// -----

// CHECK-LABEL: func.func @test_einsum_transpose
func.func @test_einsum_transpose(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
  // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->ji"
  // CHECK: %[[PATH:.*]] = torch.constant.none
  // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[4,3],f64>
  %0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->ji"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64>
  return %0 : !torch.vtensor<[4,3],f64>
}
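Each lit test above only verifies the emitted IR: the torch.aten.einsum result type must match the type declared on the onnx.Einsum op. The following NumPy sanity check (not part of the patch; equations written without the incidental whitespace that appears in the test attributes) confirms that the three single-operand equations produce the result shapes the tests declare:

import numpy as np

# Batch diagonal: (3,5,5) -> (3,5), matching !torch.vtensor<[3,5],f64>.
assert np.einsum("...ii->...i", np.ones((3, 5, 5))).shape == (3, 5)
# Row-wise sum: (3,4) -> (3,), matching !torch.vtensor<[3],f64>.
assert np.einsum("ij->i", np.ones((3, 4))).shape == (3,)
# Transpose: (3,4) -> (4,3), matching !torch.vtensor<[4,3],f64>.
assert np.einsum("ij->ji", np.ones((3, 4))).shape == (4, 3)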