mirror of https://github.com/llvm/torch-mlir
[onnx] Convert `onnx.constant` to `torch` literal tensor (#2748)
Handles the multiple cases of `onnx` constant values and converts them to `torch` literal tensors. This can include splats with a single integer or floating point value, a set of explicit integer values, or an elements array attr of values.pull/2758/head
parent
10acea71be
commit
197b3b475c
|
@ -190,6 +190,19 @@ struct OpBinder {
|
|||
return failure();
|
||||
}
|
||||
|
||||
ParseResult denseElementsAttr(ElementsAttr elementsattr,
|
||||
StringRef nameSuffix) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
name.append(nameSuffix);
|
||||
Attribute attr = op->getAttr(name);
|
||||
if (!attr || !isa<ElementsAttr>(attr)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
elementsattr = cast<ElementsAttr>(attr);
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
|
||||
std::string defaultValue = "") {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
|
|
|
@ -590,6 +590,59 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
tensorList, cstDim);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
if (binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
auto dtype = resultType.getDtype();
|
||||
Value scalarValue;
|
||||
|
||||
float floatValue;
|
||||
if (binder.op->hasAttr("torch.onnx.value_float") &&
|
||||
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
|
||||
auto splatAttr =
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
||||
rewriter.getFloatAttr(dtype, floatValue));
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, splatAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t intValue;
|
||||
if (binder.op->hasAttr("torch.onnx.value_int") &&
|
||||
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
|
||||
auto splatAttr =
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
||||
rewriter.getIntegerAttr(dtype, intValue));
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, splatAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
|
||||
.dyn_cast_or_null<ElementsAttr>()) {
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, attr);
|
||||
return success();
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> intValues;
|
||||
if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) &&
|
||||
!intValues.empty()) {
|
||||
llvm::SmallVector<APInt> apValues;
|
||||
for (auto intVal : intValues) {
|
||||
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
|
||||
}
|
||||
auto attr = DenseElementsAttr::get(
|
||||
resultType.toBuiltinTensor().clone(dtype), apValues);
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, attr);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
|
|
|
@ -979,3 +979,44 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3
|
|||
%0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>
|
||||
return %0 : !torch.vtensor<[1,2,4,6],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @float_constant
|
||||
func.func @float_constant() -> !torch.vtensor<[], f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<2.500000e-01> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @int_constant
|
||||
func.func @int_constant() -> !torch.vtensor<[], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<79> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = torch.operator "onnx.Constant"() {torch.onnx.value_int = 79 : si64} : () -> !torch.vtensor<[],si64>
|
||||
return %0 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dense_constant
|
||||
func.func @dense_constant() -> !torch.vtensor<[1], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<13> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<13> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
return %0 : !torch.vtensor<[1],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @ints_constant
|
||||
func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[7, 9]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "torch.operator"() <{name = "onnx.Constant"}> {torch.onnx.value_ints = [7 : si64, 9 : si64]} : () -> !torch.vtensor<[2],si64>
|
||||
return %0 : !torch.vtensor<[2],si64>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue