mirror of https://github.com/llvm/torch-mlir
[onnx] Support for `onnx.EyeLike` via torch lowering (#2994)
parent
9ac90ec7b2
commit
83cba8c696
|
@ -1741,6 +1741,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, data, dimValueList);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
    "EyeLike", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lower onnx.EyeLike to torch ops: produce a tensor with the same
      // shape as the input, zeros everywhere except ones on the diagonal
      // selected by the `k` attribute.
      Torch::ValueTensorType resultType;
      Value operand;
      int64_t dtypeIntOnnx, diagonalIndex;
      if (binder.tensorOperand(operand) ||
          // NOTE(review): per the ONNX spec, when `dtype` is absent the
          // output should inherit the *input* element type; the hard-coded
          // default 1 (float32) is only correct for float inputs — TODO
          // derive the default from the operand dtype and confirm.
          binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
          binder.s64IntegerAttr(diagonalIndex, "k", 0) ||
          binder.tensorResultType(resultType))
        return failure();

      auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
      // ONNX EyeLike requires a 2-D input. Guard before indexing
      // shape[0]/shape[1] below; this also rejects unranked operands,
      // for which getSizes() is not valid.
      if (!operandTy.hasSizes() || operandTy.getSizes().size() != 2)
        return rewriter.notifyMatchFailure(
            binder.op, "expected a ranked 2-D input tensor for EyeLike");
      SmallVector<int64_t> shape(operandTy.getSizes());
      // Translate ONNX/MLIR dynamic-dim sentinels to the torch dialect's.
      for (unsigned i = 0; i < shape.size(); i++) {
        if (shape[i] == ShapedType::kDynamic)
          shape[i] = Torch::kUnknownSize;
      }

      Value cst0 = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(0));
      Value cst1 = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(1));
      // Runtime row/column counts of the input (works for dynamic dims).
      Value nVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
                                                         operand, cst0);
      Value mVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
                                                         operand, cst1);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
      // Bail out on ONNX dtypes with no torch equivalent instead of
      // emitting a bogus constant (helper signals this with -1 —
      // TODO confirm against the helper's contract).
      if (dtypeIntTorch == -1)
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported ONNX dtype for EyeLike");
      Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch));

      // diagonalIndex = 0 populates the main diagonal
      // diagonalIndex > 0 populates an upper diagonal
      // diagonalIndex < 0 populates a lower diagonal
      if (diagonalIndex == 0) {
        // Main diagonal: aten.eye.m already produces exactly the result.
        rewriter.replaceOpWithNewOp<Torch::AtenEyeMOp>(
            binder.op, resultType, nVal, mVal, dtypeVal, noneVal, noneVal,
            noneVal);
        return success();
      }

      Value diagVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(),
          rewriter.getI64IntegerAttr(std::abs(diagonalIndex)));
      Value newN, newM, dimVal, startVal;
      // Shrink one dimension of the main-diagonal eye so that, once
      // slice-scattered at offset |k|, its ones land on the requested
      // off-diagonal. NOTE(review): assumes |k| < the corresponding dim;
      // a larger offset yields a non-positive eye dimension — TODO confirm
      // upstream behavior for that case.
      if (diagonalIndex > 0) {
        newN = nVal;
        newM = rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(), mVal,
                                                    diagVal);
        if (shape[1] != Torch::kUnknownSize) {
          shape[1] -= diagonalIndex;
        }
        dimVal = cst1;
        startVal = mVal;
      } else {
        newN = rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(), nVal,
                                                    diagVal);
        newM = mVal;
        if (shape[0] != Torch::kUnknownSize) {
          shape[0] += diagonalIndex;
        }
        dimVal = cst0;
        startVal = nVal;
      }

      // create main diag eye op
      auto eyeResultType = rewriter.getType<Torch::ValueTensorType>(
          shape, resultType.getOptionalDtype());
      Value eyeOp = rewriter.create<Torch::AtenEyeMOp>(
          binder.getLoc(), eyeResultType, newN, newM, dtypeVal, noneVal,
          noneVal, noneVal);
      // create zeros op with the full output shape
      SmallVector<Value> zerosShapeValues = {nVal, mVal};
      Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          zerosShapeValues);
      Value zerosOp = rewriter.create<Torch::AtenZerosOp>(
          binder.getLoc(), resultType, zerosShapeList, dtypeVal, noneVal,
          noneVal, noneVal);

      // embeds the values of the eye matrix into zeros
      rewriter.replaceOpWithNewOp<Torch::AtenSliceScatterOp>(
          binder.op, resultType, zerosOp, eyeOp, dimVal,
          /*start=*/diagVal, /*end=*/startVal, /*step=*/cst1);
      return success();
    });
|
||||
patterns.onOp(
|
||||
"Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// Flatten means to partition the input tensor's dimensions
|
||||
|
|
|
@ -1792,3 +1792,73 @@ func.func @test_einsum_transpose(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vte
|
|||
%0 = torch.operator "onnx.Einsum"(%arg0) {torch.onnx.equation = "ij->ji"} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[4,3],f64>
|
||||
return %0 : !torch.vtensor<[4,3],f64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Main-diagonal case (k omitted): EyeLike lowers to a single aten.eye.m.
// CHECK-LABEL: func.func @test_eyelike_m
func.func @test_eyelike_m(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ROWS:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.int
  // CHECK: %[[COLS:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.int
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[DTY:.*]] = torch.constant.int 6
  // CHECK: torch.aten.eye.m %[[ROWS]], %[[COLS]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
  %0 = torch.operator "onnx.EyeLike"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
  return %0 : !torch.vtensor<[3,4],f32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Explicit dtype attribute: ONNX dtype 7 (int64) maps to torch dtype 4.
// CHECK-LABEL: func.func @test_eyelike_int
func.func @test_eyelike_int(%arg0: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3], si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ROWS:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[3,3],f32>, !torch.int -> !torch.int
  // CHECK: %[[COLS:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[3,3],f32>, !torch.int -> !torch.int
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[DTY:.*]] = torch.constant.int 4
  // CHECK: torch.aten.eye.m %[[ROWS]], %[[COLS]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,3],si64>
  %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.dtype = 7 : si64} : (!torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3],si64>
  return %0 : !torch.vtensor<[3,3],si64>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Upper diagonal (k = 1): a narrowed eye is slice-scattered into zeros
// along dim 1, starting at offset k.
// CHECK-LABEL: func.func @test_eyelike_diagonal
func.func @test_eyelike_diagonal(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ROWS:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.int
  // CHECK: %[[COLS:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.int
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[DTY:.*]] = torch.constant.int 6
  // CHECK: %[[K:.*]] = torch.constant.int 1
  // CHECK: %[[SUBCOLS:.*]] = torch.aten.sub.int %[[COLS]], %[[K]] : !torch.int, !torch.int -> !torch.int
  // CHECK: %[[EYE:.*]] = torch.aten.eye.m %[[ROWS]], %[[SUBCOLS]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,3],f32>
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ROWS]], %[[COLS]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[ZERO:.*]] = torch.aten.zeros %[[SHAPE]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
  // CHECK: torch.aten.slice_scatter %[[ZERO]], %[[EYE]], %[[C1]], %[[K]], %[[COLS]], %[[C1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
  %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = 1 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
  return %0 : !torch.vtensor<[3,4],f32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower diagonal (k = -1) with a dynamic column dimension: the eye is
// shrunk along dim 0 and scattered into zeros along dim 0.
// CHECK-LABEL: func.func @test_eyelike_dynamic
func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ROWS:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[3,?],f32>, !torch.int -> !torch.int
  // CHECK: %[[COLS:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[3,?],f32>, !torch.int -> !torch.int
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[DTY:.*]] = torch.constant.int 6
  // CHECK: %[[K:.*]] = torch.constant.int 1
  // CHECK: %[[SUBROWS:.*]] = torch.aten.sub.int %[[ROWS]], %[[K]] : !torch.int, !torch.int -> !torch.int
  // CHECK: %[[EYE:.*]] = torch.aten.eye.m %[[SUBROWS]], %[[COLS]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,?],f32>
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ROWS]], %[[COLS]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[ZERO:.*]] = torch.aten.zeros %[[SHAPE]], %[[DTY]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,?],f32>
  // CHECK: torch.aten.slice_scatter %[[ZERO]], %[[EYE]], %[[C0]], %[[K]], %[[ROWS]], %[[C1]] : !torch.vtensor<[3,?],f32>, !torch.vtensor<[2,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?],f32>
  %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
  return %0 : !torch.vtensor<[3,?],f32>
}
|
||||
|
|
Loading…
Reference in New Issue