diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 54cfb3e2a..bed22b084 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1472,4 +1472,85 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "ConstantOfShape", 20, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value shape; + if (binder.tensorOperand(shape) || binder.tensorResultType(resultType)) + return failure(); + + // convert shape tensor to list of ints + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Torch::BaseTensorType shapeType = + shape.getType().cast(); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < shapeSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value noneVal = rewriter.create(binder.getLoc()); + + // Get fill_value if it is present. + // Assumption : resultDType and value attr type match. + Value value_const; + auto attr = binder.op->getAttr("torch.onnx.value"); + auto resultDType = resultType.getDtype(); + + // Extract the fill value and dtype + // ONNX requires value attr to be a tensor + if (!attr) { + attr = DenseElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDType), + rewriter.getFloatAttr(resultDType, 0.0)); + } + if (!isa(attr)) { + return rewriter.notifyMatchFailure( + binder.op, "`value` attr needs to be a tensor."); + } + + auto denseAttr = attr.cast(); + auto denseAttrEleType = denseAttr.getElementType(); + if (!isa(denseAttrEleType)) { + return rewriter.notifyMatchFailure( + binder.op, + "`value` attr tensor only supports types int and float for now."); + } + + // Create constant op for value + if (denseAttrEleType.isa()) { + int64_t intVal = denseAttr.getSplatValue().getSInt(); + value_const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(intVal)); + } + if (denseAttrEleType.isa()) { + float floatVal = + denseAttr.getSplatValue().getValue().convertToFloat(); + value_const = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(floatVal)); + } + + rewriter.replaceOpWithNewOp( + binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal, + /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 493cdc983..2c06567bd 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1413,3 +1413,75 @@ func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> return %0 : !torch.vtensor<[2,1],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_default +func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> : (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_cst +func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3.4> : tensor<1xf32>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_int_cst +func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3> : tensor<1xsi64>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], si64> + return %0 : !torch.vtensor<[2,3,4], si64> +}