mirror of https://github.com/llvm/torch-mlir
Drop torch attributes at the end of backend conversion. (#2876)
Fixes https://github.com/llvm/torch-mlir/issues/2866 Some backends / downstream projects expect that a "fully converted" program has no remaining ops or attributes from the original dialect(s).pull/2909/head
parent
24c2fc0b5f
commit
d6e1d836ca
|
@ -115,6 +115,21 @@ static void setupFinalization(ConversionTarget &target,
|
|||
setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter);
|
||||
}
|
||||
|
||||
static void stripTorchAttrs(func::FuncOp func) {
|
||||
bool modified = false;
|
||||
SmallVector<NamedAttribute> newAttrs;
|
||||
for (auto attr : func->getDialectAttrs()) {
|
||||
if (attr.getName().getValue().starts_with("torch."))
|
||||
modified = true;
|
||||
else
|
||||
newAttrs.push_back(attr);
|
||||
}
|
||||
if (modified)
|
||||
func->setDialectAttrs(newAttrs);
|
||||
|
||||
// Note: this could also strip "arg" and "result" attrs if they were used.
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct FinalizingBackendTypeConversionPass
|
||||
: public FinalizingBackendTypeConversionBase<
|
||||
|
@ -151,6 +166,9 @@ struct FinalizingBackendTypeConversionPass
|
|||
|
||||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
||||
// Drop attributes that are no longer used after conversion out of Torch.
|
||||
stripTorchAttrs(func);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -30,7 +30,7 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.matmul.2d
|
||||
func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
|
||||
// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
|
||||
// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
|
||||
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
|
||||
|
|
|
@ -54,6 +54,20 @@ func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @eliminate_attributes()
|
||||
// CHECK-NOT: attributes
|
||||
// CHECK-NOT: torch.onnx_meta
|
||||
func.func @eliminate_attributes() attributes {
|
||||
torch.onnx_meta.ir_version = 8 : si64,
|
||||
torch.onnx_meta.opset_version = 17 : si64,
|
||||
torch.onnx_meta.producer_name = "pytorch",
|
||||
torch.onnx_meta.producer_version = "2.1.0"
|
||||
} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
|
||||
// expected-error @+1 {{failed to legalize operation 'test.source'}}
|
||||
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||
|
|
Loading…
Reference in New Issue