Drop torch attributes at the end of backend conversion. (#2876)

Fixes https://github.com/llvm/torch-mlir/issues/2866

Some backends / downstream projects expect that a "fully converted"
program has no remaining ops or attributes from the original dialect(s).
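
For concreteness: the onnx importer records metadata as "torch."-prefixed dialect attributes (torch.onnx_meta.* in the tests below), and before this change those attributes survived the finalizing conversion. A rough before/after sketch, with the function name and attribute values made up for illustration (modeled on the new test case):

// Before finalization: importer metadata is still attached to the func.
func.func @example() attributes {
  torch.onnx_meta.producer_name = "pytorch",
  torch.onnx_meta.producer_version = "2.1.0"
} {
  return
}

// After finalization: "torch."-prefixed dialect attributes are stripped.
func.func @example() {
  return
}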
Author: Scott Todd (committed by GitHub)
Date:   2024-02-13 14:32:02 -08:00
Parent: 24c2fc0b5f
Commit: d6e1d836ca
3 changed files with 33 additions and 1 deletion

@@ -115,6 +115,21 @@ static void setupFinalization(ConversionTarget &target,
  setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter);
}

static void stripTorchAttrs(func::FuncOp func) {
  bool modified = false;
  SmallVector<NamedAttribute> newAttrs;
  for (auto attr : func->getDialectAttrs()) {
    if (attr.getName().getValue().starts_with("torch."))
      modified = true;
    else
      newAttrs.push_back(attr);
  }
  if (modified)
    func->setDialectAttrs(newAttrs);
  // Note: this could also strip "arg" and "result" attrs if they were used.
}
namespace {
struct FinalizingBackendTypeConversionPass
    : public FinalizingBackendTypeConversionBase<
@@ -151,6 +166,9 @@ struct FinalizingBackendTypeConversionPass
    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();

    // Drop attributes that are no longer used after conversion out of Torch.
    stripTorchAttrs(func);
  }
};
} // namespace
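
To see the stripping end to end, the finalizing-conversion tests below can be fed through the pass driver. A sketch of the invocation, assuming the torch-mlir-opt tool and the -torch-finalizing-backend-type-conversion pass flag (the RUN line itself is not part of this diff and the exact flags may differ):

// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics | FileCheck %s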

@@ -30,7 +30,7 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// -----
// CHECK-LABEL: func.func @torch.aten.matmul.2d
-func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32

@@ -54,6 +54,20 @@ func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
// -----
// CHECK-LABEL: func.func @eliminate_attributes()
// CHECK-NOT: attributes
// CHECK-NOT: torch.onnx_meta
func.func @eliminate_attributes() attributes {
  torch.onnx_meta.ir_version = 8 : si64,
  torch.onnx_meta.opset_version = 17 : si64,
  torch.onnx_meta.producer_name = "pytorch",
  torch.onnx_meta.producer_version = "2.1.0"
} {
  return
}
// -----
func.func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
  // expected-error @+1 {{failed to legalize operation 'test.source'}}
  %0 = "test.source"() : () -> !torch.vtensor<[],f32>