mirror of https://github.com/llvm/torch-mlir
[onnx] `onnx.CenterCropPad` used an incorrect type for toScalar (#3605)
toScalar should take a rank-0 tensor type, not a rank-1 tensor of length 1. Changing the type allows proper compilation.

pull/3617/head
parent 59a4c6fda4
commit 6c33ab024e
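For orientation, and not part of the commit itself: a minimal C++ sketch of the type distinction the fix relies on. The builder calls mirror the ones in the patch below, but the helper function, its name, and the includes are illustrative assumptions rather than code from the repository.

```cpp
#include <cstdint>
#include <utility>

#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

// Illustrative helper (not in the patch): build the two index tensor types.
static std::pair<Type, Type> makeCropPadIndexTypes(PatternRewriter &rewriter) {
  // Rank-0 "scalar" tensor: empty shape, prints as !torch.vtensor<[],si64>.
  // This is what a NumToTensor.Scalar/item round-trip expects.
  Type scalarTensorType = rewriter.getType<Torch::ValueTensorType>(
      ArrayRef<int64_t>{}, rewriter.getIntegerType(64, /*signed=*/true));
  // Rank-1, one-element tensor: shape {1}, prints as !torch.vtensor<[1],si64>.
  // index_select on a rank-1 input keeps rank 1, so its result uses this.
  Type selectTensorType = rewriter.getType<Torch::ValueTensorType>(
      ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed=*/true));
  return {scalarTensorType, selectTensorType};
}
```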
@@ -759,6 +759,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getI64IntegerAttr(2));
         auto scalarTensorType = rewriter.getType<Torch::ValueTensorType>(
-            ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed*/ 1));
+            ArrayRef<int64_t>{}, rewriter.getIntegerType(64, /*signed*/ 1));
+        auto selectTensorType = rewriter.getType<Torch::ValueTensorType>(
+            ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed*/ 1));

         int64_t lastChangeDim = 0;
@@ -790,7 +792,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
               binder.getLoc(), scalarTensorType, k);
           Value sel = rewriter.create<Torch::AtenIndexSelectOp>(
-              binder.getLoc(), scalarTensorType, shape, cstZero, kTensor);
+              binder.getLoc(), selectTensorType, shape, cstZero, kTensor);
           Value outputDimSize = rewriter.create<Torch::AtenItemOp>(
               binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
           Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
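The FileCheck updates below follow mechanically from this: the `NumToTensor.Scalar` results become `!torch.vtensor<[],si64>`, while `index_select` still yields a one-element `[1]` tensor. The arithmetic the CHECK lines repeat is the usual centering offset. As a plain C++ sketch of the crop-side computation (a hypothetical helper, not repository code; `floorDiv` is spelled out because `torch.aten.div.Scalar_mode` with "floor" rounds toward negative infinity while C++ `/` truncates toward zero):

```cpp
#include <cstdint>

// Floor division matching torch.aten.div.Scalar_mode(..., "floor").
static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b, r = a % b;
  return (r != 0 && (r < 0) != (b < 0)) ? q - 1 : q;
}

// Crop-side centering for one dimension; the pad side swaps the operands
// of the subtraction, as the SUB_0 lines in the tests below show.
static void centerCrop(int64_t inputDimSize, int64_t outputDimSize,
                       int64_t &start, int64_t &end) {
  int64_t sub = inputDimSize - outputDimSize; // %[[SUB]]
  start = floorDiv(sub, 2);                   // %[[DIV]] / %[[ITEM_0]]
  end = start + outputDimSize;                // %[[ADD]]
}
```

The slice `[start, end)` then implements the centered crop along that dimension.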
@@ -2526,28 +2526,28 @@ func.func @test_center_crop_pad_crop_and_pad(%arg0: !torch.vtensor<[20,8,3],f32>
   // CHECK: %[[STR:.*]] = torch.constant.str "floor"
   // CHECK: %[[C0_0:.*]] = torch.constant.int 0
   // CHECK: %[[C0_1:.*]] = torch.constant.int 0
-  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C0_2:.*]] = torch.constant.int 0
   // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
   // CHECK: %[[C1_1:.*]] = torch.constant.int 1
-  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C1_2:.*]] = torch.constant.int 1
   // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[C0_3:.*]] = torch.constant.int 0
   // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
@@ -2571,28 +2571,28 @@ func.func @test_center_crop_pad_crop_axes_chw(%arg0: !torch.vtensor<[3,20,8],f32
   // CHECK: %[[STR:.*]] = torch.constant.str "floor"
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
   // CHECK: %[[C0_0:.*]] = torch.constant.int 0
-  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C1_1:.*]] = torch.constant.int 1
   // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C1_1]] : !torch.vtensor<[3,20,8],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[3,20,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
   // CHECK: %[[C2_0:.*]] = torch.constant.int 2
   // CHECK: %[[C1_2:.*]] = torch.constant.int 1
-  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C2_1:.*]] = torch.constant.int 2
   // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[C0_1:.*]] = torch.constant.int 0
   // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
@@ -2616,28 +2616,28 @@ func.func @test_center_crop_pad_crop_negative_axes_hwc(%arg0: !torch.vtensor<[20
   // CHECK: %[[STR:.*]] = torch.constant.str "floor"
   // CHECK: %[[C0_0:.*]] = torch.constant.int 0
   // CHECK: %[[C0_1:.*]] = torch.constant.int 0
-  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C0_2:.*]] = torch.constant.int 0
   // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
   // CHECK: %[[C1_1:.*]] = torch.constant.int 1
-  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
   // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
   // CHECK: %[[C1_2:.*]] = torch.constant.int 1
   // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
   // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[C0_3:.*]] = torch.constant.int 0
   // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int