Fix Base Lazy Backend Type Conversion (#1412)

* Fix c10::prim::Constant conversion; Added CAPI for passes; Added passes to base lazy backend

* Update ivalue_importer to use ImportOptions; Added tests for non-value/value tensor types

* Added tests for scalar Constant import; Updated MB::importFunction to use ImportOptions

* Test updates

* Move back module variable name

* Remove RefineTypes from TorchMlirLoweringContext::Build()

* Rename pass; Remove passes from base lazy backend

* Rename pass to VerifyBackendContractPass

* Aligned cmd pass name; Fixed TorchConversion passes registration
pull/1214/head snapshot-20221005.617
Gleb Kazantaev 2022-10-04 18:53:28 -04:00 committed by GitHub
parent eda18e351c
commit 708fa346a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 234 additions and 35 deletions

View File

@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context); torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets the !torch.vtensor type with the tensor attribute.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.none type. // !torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,22 @@
//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the registration and creation method for
// transformation passes.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_C_TRANSFORMS_H
#define TORCHMLIR_C_TRANSFORMS_H
#include "mlir-c/Support.h"
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
#endif // TORCHMLIR_C_TRANSFORMS_H

View File

@ -1,5 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
add_public_tablegen_target(TorchMLIRTorchPassIncGen) add_public_tablegen_target(TorchMLIRTorchPassIncGen)
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc) add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)

View File

@ -21,6 +21,8 @@ class ModuleOp;
namespace torch { namespace torch {
namespace Torch { namespace Torch {
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass(); std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose, createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps); ArrayRef<std::string> backendLegalOps);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
StringRef getShapeLibrary(); StringRef getShapeLibrary();
} // namespace Torch } // namespace Torch
@ -116,6 +120,13 @@ StringRef getShapeLibrary();
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.
void registerTorchPasses(); void registerTorchPasses();
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -329,4 +329,16 @@ def LowerToBackendContract
let dependentDialects = ["func::FuncDialect"]; let dependentDialects = ["func::FuncDialect"];
} }
def VerifyBackendContract
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
let summary = "Check that program satisfies backend contract.";
let constructor =
"mlir::torch::Torch::createVerifyBackendContractPass()";
let description = [{
This pass performs a set of inspections to check that program satisfies backend
contract. In case of check failure it prints out the error message and returns
`signalPassFailure()` status.
}];
}
#endif // TORCHMLIR_TORCH_PASSES #endif // TORCHMLIR_TORCH_PASSES

View File

@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
Registration.cpp Registration.cpp
TorchOps.cpp TorchOps.cpp
TorchTypes.cpp TorchTypes.cpp
Transforms.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/ ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
MLIRSupport MLIRSupport
TorchMLIRTorchDialect TorchMLIRTorchDialect
TorchMLIRInitAll TorchMLIRInitAll
TorchMLIRTorchPasses
) )
torch_mlir_target_includes(TorchMLIRCAPI) torch_mlir_target_includes(TorchMLIRCAPI)

View File

@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context))); Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
} }
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
auto attrTensorType =
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
attrTensorType.getShape(),
attrTensorType.getElementType()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.none type. // torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,27 @@
//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/CAPI/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
// Must include the declarations as they carry important visibility attributes.
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
#ifdef __cplusplus
extern "C" {
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
#ifdef __cplusplus
}
#endif

View File

@ -242,6 +242,16 @@ public:
}); });
} }
}; };
class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
public:
void runOnOperation() override {
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
return signalPassFailure();
}
}
};
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose, return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
backendLegalOps); backendLegalOps);
} }
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass() {
return std::make_unique<VerifyBackendContractPass>();
}

View File

@ -11,17 +11,8 @@
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // end namespace
void mlir::torch::registerTorchPasses() { void mlir::torch::registerTorchPasses() {
::registerPasses(); mlir::torch::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-module-to-torch-backend-pipeline", "torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",

View File

@ -34,13 +34,13 @@ using namespace mlir::tosa;
// Pass registration // Pass registration
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace reg {
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
} // end namespace } // end namespace reg
void mlir::torch::registerTorchConversionPasses() { void mlir::torch::registerTorchConversionPasses() {
::registerPasses(); reg::registerPasses();
mlir::PassPipelineRegistration<>( mlir::PassPipelineRegistration<>(
"torch-backend-to-linalg-on-tensors-backend-pipeline", "torch-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch backend contract to linalg-on-tensors backend " "Pipeline lowering torch backend contract to linalg-on-tensors backend "

View File

@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc, importBlock, "torch.prim.ListConstruct", loc,
torchMlirTorchListTypeGet( torchMlirTorchListTypeGet(
getMlirTypeFromTorchType(loc, list.elementType())), getMlirTypeFromTorchType(loc, list.elementType(), importOptions)),
elems); elems);
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc, importBlock, "torch.prim.DictConstruct", loc,
torchMlirTorchDictTypeGet( torchMlirTorchDictTypeGet(
getMlirTypeFromTorchType(loc, dict.keyType()), getMlirTypeFromTorchType(loc, dict.keyType(), importOptions),
getMlirTypeFromTorchType(loc, dict.valueType())), getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)),
keys, values); keys, values);
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
at::Tensor tensor = ivalue.toTensor().contiguous(); at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd( MlirOperation tensorOp;
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements), if (importOptions.assumeTensorsHaveValueSemantics) {
toMlirNamedAttribute("value", denseElements)); tensorOp = createMlirOperationAtEnd(
importBlock, "torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
} else {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
}
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0); MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
// Construct the complete tensor value. This is trivial for most tensors, but // Construct the complete tensor value. This is trivial for most tensors, but
@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// compiler stages that are building a statically modeled quantization // compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation. // representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end()); std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet( MlirType quantizedTensorType;
context, shape.size(), shape.data(), if (importOptions.assumeTensorsHaveValueSemantics) {
getMlirTypeForTorchScalarType(loc, tensor.scalar_type())); quantizedTensorType = torchMlirTorchValueTensorTypeGet(
context, shape.size(), shape.data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
} else {
quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
}
if (tensor.qscheme() == c10::kPerTensorAffine) { if (tensor.qscheme() == c10::kPerTensorAffine) {
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale())); MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point())); MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
"name", mlirStringAttrGet( "name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))), context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType( toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
loc, classAttribute.getType()))), loc, classAttribute.getType(), importOptions))),
isPrivate); isPrivate);
} }

View File

@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
} }
torch::jit::StrongFunctionPtr torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
py::object maybeImportOptions) {
ImportOptions importOptions;
if (!maybeImportOptions.is_none()) {
importOptions = py::cast<ImportOptions>(maybeImportOptions);
}
MlirBlock block = getBodyBlock(); MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator; MlirOperation terminator = this->terminator;
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_); MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
mlirBlockInsertOwnedOperationBefore(block, terminator, func); mlirBlockInsertOwnedOperationBefore(block, terminator, func);
return function; return function;
} }
@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) {
.def(py::init<py::object>(), py::arg("context") = py::none()) .def(py::init<py::object>(), py::arg("context") = py::none())
.def_property_readonly("context", &ModuleBuilder::getContextObj) .def_property_readonly("context", &ModuleBuilder::getContextObj)
.def_property_readonly("module", &ModuleBuilder::getModuleObj) .def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("import_function", &ModuleBuilder::importFunction) .def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
py::arg("importOptions") = py::none())
.def("import_module", &ModuleBuilder::importModule, py::arg("module"), .def("import_module", &ModuleBuilder::importModule, py::arg("module"),
py::arg("classAnnotator") = py::none(), py::arg("classAnnotator") = py::none(),
py::arg("importOptions") = py::none()); py::arg("importOptions") = py::none());

View File

@ -39,7 +39,8 @@ public:
// Just a bit of naming cruft. // Just a bit of naming cruft.
// Returns the same function, making it suitable as a nested decorator. // Returns the same function, making it suitable as a nested decorator.
torch::jit::StrongFunctionPtr torch::jit::StrongFunctionPtr
importFunction(torch::jit::StrongFunctionPtr function); importFunction(torch::jit::StrongFunctionPtr function,
py::object maybeImportOptions);
// Imports a torch::jit::Module into the current module, using the // Imports a torch::jit::Module into the current module, using the
// annotations, if not none, provided in `maybeClassAnnotator` which should be // annotations, if not none, provided in `maybeClassAnnotator` which should be

View File

@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
c10::attr::value))))); c10::attr::value)))));
} else if (output->type()->cast<c10::TensorType>()) { } else if (output->type()->cast<c10::TensorType>()) {
MlirAttribute attr = importAttribute(loc, node, c10::attr::value); MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
op = createMlirOperation( if (importOptions.assumeTensorsHaveValueSemantics) {
"torch.tensor.literal", loc, op = createMlirOperation(
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr), "torch.vtensor.literal", loc,
toMlirNamedAttribute("value", attr)); torchMlirTorchValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
} else {
op = createMlirOperation(
"torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
}
} else if (output->type()->cast<c10::DeviceObjType>()) { } else if (output->type()->cast<c10::DeviceObjType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.device", loc, "torch.constant.device", loc,

View File

@ -0,0 +1,42 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.
import typing
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ones_i32 = torch.ones(1, dtype=torch.int32)
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
self.arange = torch.nn.Parameter(torch.arange(3.0))
# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8>
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32>
# CHECK: torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32>
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8>
# CHECK: }
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
import_options = ImportOptions()
import_options.assumeTensorsHaveValueSemantics = True
class_annotator = ClassAnnotator()
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator, import_options)
mb.module.operation.print()

View File

@ -5,7 +5,7 @@
import typing import typing
import torch import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from utils import create_script_function from utils import create_script_function
@ -162,3 +162,34 @@ graph():
mb.module.operation.print() mb.module.operation.print()
print() print()
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number {
# CHECK: %[[A:.*]] = torch.tensor.literal
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
# CHECK: return %[[RET]] : !torch.number
import_options = ImportOptions()
import_options.assumeTensorsHaveValueSemantics = False
mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """
graph():
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
%1 : Scalar = aten::ScalarImplicit(%0)
return (%1)
""", parse_tensor_constants=True), import_options)
mb.module.operation.print()
print()
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number {
# CHECK: %[[A:.*]] = torch.vtensor.literal
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
# CHECK: return %[[RET]] : !torch.number
import_options.assumeTensorsHaveValueSemantics = True
mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """
graph():
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
%1 : Scalar = aten::ScalarImplicit(%0)
return (%1)
""", parse_tensor_constants=True), import_options)
mb.module.operation.print()
print()

View File

@ -10,6 +10,6 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s # RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction. # Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str): def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit() cu = CompilationUnit()
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)) return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs))