[onnx] Convert `onnx.constant` to `torch` literal tensor (#2748)

Handles the various forms an `onnx` constant value can take and converts
each to a `torch` literal tensor. This includes splats built from a single
integer or floating-point value (`value_int`, `value_float`), an explicit
list of integers (`value_ints`), and a dense elements attribute (`value`).
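
For example (mirroring the new lit tests below), a float splat constant now
lowers roughly as:

// ONNX-imported form:
%0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32>
// After conversion:
%0 = torch.vtensor.literal(dense<2.500000e-01> : tensor<f32>) : !torch.vtensor<[],f32>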
Rob Suderman 2024-01-15 09:31:22 -08:00 committed by GitHub
parent 10acea71be
commit 197b3b475c
3 changed files with 107 additions and 0 deletions


@@ -190,6 +190,19 @@ struct OpBinder {
    return failure();
  }

  // Bind the ElementsAttr stored under "torch.onnx.<nameSuffix>", if present.
  ParseResult denseElementsAttr(ElementsAttr &elementsattr,
                                StringRef nameSuffix) {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    Attribute attr = op->getAttr(name);
    if (!attr || !isa<ElementsAttr>(attr)) {
      return failure();
    }
    elementsattr = cast<ElementsAttr>(attr);
    return success();
  }

  ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                     std::string defaultValue = "") {
    SmallString<64> name("torch.onnx.");
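
The Constant lowering in this commit reads the attribute directly, but a
minimal sketch of how a pattern could use the new binder helper (assuming
`binder`, `rewriter`, and a bound `resultType` are in scope as in the other
onOp lambdas):

// Sketch only: bind "torch.onnx.value" through the new helper and emit a
// torch.vtensor.literal from it.
ElementsAttr valueAttr;
if (binder.denseElementsAttr(valueAttr, "value"))
  return failure();
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(binder.op, resultType,
                                                         valueAttr);
return success();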


@@ -590,6 +590,59 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            tensorList, cstDim);
        return success();
      });
  patterns.onOp(
      "Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        if (binder.tensorResultType(resultType))
          return failure();
        auto dtype = resultType.getDtype();

        // Case 1: a single float given as `value_float` becomes a splat.
        float floatValue;
        if (binder.op->hasAttr("torch.onnx.value_float") &&
            !binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
          auto splatAttr =
              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
                                     rewriter.getFloatAttr(dtype, floatValue));
          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
              binder.op, resultType, splatAttr);
          return success();
        }

        // Case 2: a single integer given as `value_int` becomes a splat.
        int64_t intValue;
        if (binder.op->hasAttr("torch.onnx.value_int") &&
            !binder.s64IntegerAttr(intValue, "value_int", 0)) {
          auto splatAttr =
              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
                                     rewriter.getIntegerAttr(dtype, intValue));
          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
              binder.op, resultType, splatAttr);
          return success();
        }

        // Case 3: a dense elements attribute given as `value` is reused
        // directly as the literal.
        if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
                                    .dyn_cast_or_null<ElementsAttr>()) {
          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
              binder.op, resultType, attr);
          return success();
        }

        // Case 4: an explicit list of integers given as `value_ints` is
        // packed into a dense elements attribute of the result dtype.
        llvm::SmallVector<int64_t> intValues;
        if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) &&
            !intValues.empty()) {
          llvm::SmallVector<APInt> apValues;
          for (auto intVal : intValues)
            apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
          auto attr = DenseElementsAttr::get(
              resultType.toBuiltinTensor().clone(dtype), apValues);
          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
              binder.op, resultType, attr);
          return success();
        }

        return failure();
      });
  patterns.onOp(
      "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;


@@ -979,3 +979,44 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3
%0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>
return %0 : !torch.vtensor<[1,2,4,6],f32>
}
// -----
// CHECK-LABEL: @float_constant
func.func @float_constant() -> !torch.vtensor<[], f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<2.500000e-01> : tensor<f32>) : !torch.vtensor<[],f32>
// CHECK: return %[[CST]]
%0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: @int_constant
func.func @int_constant() -> !torch.vtensor<[], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<79> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]]
%0 = torch.operator "onnx.Constant"() {torch.onnx.value_int = 79 : si64} : () -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
// -----
// CHECK-LABEL: @dense_constant
func.func @dense_constant() -> !torch.vtensor<[1], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<13> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK: return %[[CST]]
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<13> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}
// -----
// CHECK-LABEL: @ints_constant
func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[7, 9]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
// CHECK: return %[[CST]]
%0 = "torch.operator"() <{name = "onnx.Constant"}> {torch.onnx.value_ints = [7 : si64, 9 : si64]} : () -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}