diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5f3a2609b..5dd3d778f 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,6 +115,21 @@ static void setupFinalization(ConversionTarget &target, setupFinalization(target, patterns, typeConverter); } +static void stripTorchAttrs(func::FuncOp func) { + bool modified = false; + SmallVector<NamedAttribute> newAttrs; + for (auto attr : func->getDialectAttrs()) { + if (attr.getName().getValue().starts_with("torch.")) + modified = true; + else + newAttrs.push_back(attr); + } + if (modified) + func->setDialectAttrs(newAttrs); + + // Note: this could also strip "arg" and "result" attrs if they were used. +} + namespace { struct FinalizingBackendTypeConversionPass : public FinalizingBackendTypeConversionBase< @@ -151,6 +166,9 @@ struct FinalizingBackendTypeConversionPass if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); + + // Drop attributes that are no longer used after conversion out of Torch. 
+ stripTorchAttrs(func); } }; } // namespace diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index cfb252cd1..f063f234e 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -30,7 +30,7 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- // CHECK-LABEL: func.func @torch.aten.matmul.2d -func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index a16da0932..46f80c06b 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -54,6 +54,20 @@ func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 { // ----- +// CHECK-LABEL: func.func @eliminate_attributes() +// CHECK-NOT: attributes +// CHECK-NOT: torch.onnx_meta +func.func @eliminate_attributes() attributes { + torch.onnx_meta.ir_version = 8 : si64, + torch.onnx_meta.opset_version = 17 : si64, + torch.onnx_meta.producer_name = "pytorch", + torch.onnx_meta.producer_version = "2.1.0" +} { + return +} + +// ----- + func.func 
@unable_to_convert_lone_buffer_cast() -> tensor<f32> { // expected-error @+1 {{failed to legalize operation 'test.source'}} %0 = "test.source"() : () -> !torch.vtensor<[],f32>