mirror of https://github.com/llvm/torch-mlir
Fix Base Lazy Backend Type Conversion (#1412)
* Fix c10::prim::Constant conversion; Added CAPI for passes; Added passes to base lazy backend * Update ivalue_importer to use ImportOptions; Added tests for non-value/value tensor types * Added tests for scalar Constant import; Updated MB::importFunction to use ImportOptions * Test updates * Move back module variable name * Remove RefineTypes from TorchMlirLoweringContext::Build() * Rename pass; Remove passes from base lazy backend * Rename pass to VerifyBackendContractPass * Aligned cmd pass name; Fixed TorchConversion passes registrationpull/1214/head snapshot-20221005.617
parent
eda18e351c
commit
708fa346a6
|
@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
|
|||
MLIR_CAPI_EXPORTED MlirType
|
||||
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/// Gets the !torch.vtensor type with the tensor attribute.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.none type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header declares the registration and creation method for
|
||||
// transformation passes.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_C_TRANSFORMS_H
|
||||
#define TORCHMLIR_C_TRANSFORMS_H
|
||||
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
||||
|
||||
#endif // TORCHMLIR_C_TRANSFORMS_H
|
|
@ -1,5 +1,7 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
|
||||
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
|
||||
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
|
||||
|
|
|
@ -21,6 +21,8 @@ class ModuleOp;
|
|||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
|
|||
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
||||
ArrayRef<std::string> backendLegalOps);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
|
||||
|
||||
StringRef getShapeLibrary();
|
||||
|
||||
} // namespace Torch
|
||||
|
@ -116,6 +120,13 @@ StringRef getShapeLibrary();
|
|||
/// Registers all Torch transformation passes.
|
||||
void registerTorchPasses();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -329,4 +329,16 @@ def LowerToBackendContract
|
|||
let dependentDialects = ["func::FuncDialect"];
|
||||
}
|
||||
|
||||
def VerifyBackendContract
|
||||
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
|
||||
let summary = "Check that program satisfies backend contract.";
|
||||
let constructor =
|
||||
"mlir::torch::Torch::createVerifyBackendContractPass()";
|
||||
let description = [{
|
||||
This pass performs a set of inspections to check that program satisfies backend
|
||||
contract. In case of check failure it prints out the error message and returns
|
||||
`signalPassFailure()` status.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_TORCH_PASSES
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
|
|||
Registration.cpp
|
||||
TorchOps.cpp
|
||||
TorchTypes.cpp
|
||||
Transforms.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
||||
|
@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
|
|||
MLIRSupport
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRInitAll
|
||||
TorchMLIRTorchPasses
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRCAPI)
|
||||
|
|
|
@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
|
|||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||
auto attrTensorType =
|
||||
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
|
||||
return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
|
||||
attrTensorType.getShape(),
|
||||
attrTensorType.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.none type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/CAPI/Pass.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
// Must include the declarations as they carry important visibility attributes.
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -242,6 +242,16 @@ public:
|
|||
});
|
||||
}
|
||||
};
|
||||
|
||||
class VerifyBackendContractPass
|
||||
: public VerifyBackendContractBase<VerifyBackendContractPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
|
|||
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
|
||||
backendLegalOps);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::Torch::createVerifyBackendContractPass() {
|
||||
return std::make_unique<VerifyBackendContractPass>();
|
||||
}
|
||||
|
|
|
@ -11,17 +11,8 @@
|
|||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::torch::registerTorchPasses() {
|
||||
::registerPasses();
|
||||
mlir::torch::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-module-to-torch-backend-pipeline",
|
||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||
|
|
|
@ -34,13 +34,13 @@ using namespace mlir::tosa;
|
|||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
namespace reg {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
||||
} // end namespace
|
||||
} // end namespace reg
|
||||
|
||||
void mlir::torch::registerTorchConversionPasses() {
|
||||
::registerPasses();
|
||||
reg::registerPasses();
|
||||
mlir::PassPipelineRegistration<>(
|
||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
||||
|
|
|
@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.ListConstruct", loc,
|
||||
torchMlirTorchListTypeGet(
|
||||
getMlirTypeFromTorchType(loc, list.elementType())),
|
||||
getMlirTypeFromTorchType(loc, list.elementType(), importOptions)),
|
||||
elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
|
@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.DictConstruct", loc,
|
||||
torchMlirTorchDictTypeGet(
|
||||
getMlirTypeFromTorchType(loc, dict.keyType()),
|
||||
getMlirTypeFromTorchType(loc, dict.valueType())),
|
||||
getMlirTypeFromTorchType(loc, dict.keyType(), importOptions),
|
||||
getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)),
|
||||
keys, values);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
|
@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
|
||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirOperation tensorOp;
|
||||
|
||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.vtensor.literal", loc,
|
||||
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
} else {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
}
|
||||
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
||||
// Construct the complete tensor value. This is trivial for most tensors, but
|
||||
|
@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
// compiler stages that are building a statically modeled quantization
|
||||
// representation will need to convert this to their representation.
|
||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||
MlirType quantizedTensorType;
|
||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||
quantizedTensorType = torchMlirTorchValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||
} else {
|
||||
quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||
}
|
||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
|
||||
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
|
||||
|
@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
|||
"name", mlirStringAttrGet(
|
||||
context, toMlirStringRef(classAttribute.getName()))),
|
||||
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||
loc, classAttribute.getType()))),
|
||||
loc, classAttribute.getType(), importOptions))),
|
||||
isPrivate);
|
||||
}
|
||||
|
||||
|
|
|
@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
|||
}
|
||||
|
||||
torch::jit::StrongFunctionPtr
|
||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
|
||||
py::object maybeImportOptions) {
|
||||
ImportOptions importOptions;
|
||||
if (!maybeImportOptions.is_none()) {
|
||||
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
||||
}
|
||||
MlirBlock block = getBodyBlock();
|
||||
MlirOperation terminator = this->terminator;
|
||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
|
||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
|
||||
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||
return function;
|
||||
}
|
||||
|
@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) {
|
|||
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||
.def("import_function", &ModuleBuilder::importFunction)
|
||||
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
|
||||
py::arg("importOptions") = py::none())
|
||||
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||
py::arg("classAnnotator") = py::none(),
|
||||
py::arg("importOptions") = py::none());
|
||||
|
|
|
@ -39,7 +39,8 @@ public:
|
|||
// Just a bit of naming cruft.
|
||||
// Returns the same function, making it suitable as a nested decorator.
|
||||
torch::jit::StrongFunctionPtr
|
||||
importFunction(torch::jit::StrongFunctionPtr function);
|
||||
importFunction(torch::jit::StrongFunctionPtr function,
|
||||
py::object maybeImportOptions);
|
||||
|
||||
// Imports a torch::jit::Module into the current module, using the
|
||||
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
||||
|
|
|
@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
c10::attr::value)))));
|
||||
} else if (output->type()->cast<c10::TensorType>()) {
|
||||
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
||||
op = createMlirOperation(
|
||||
"torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
||||
toMlirNamedAttribute("value", attr));
|
||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||
op = createMlirOperation(
|
||||
"torch.vtensor.literal", loc,
|
||||
torchMlirTorchValueTensorTypeGetFromAttribute(attr),
|
||||
toMlirNamedAttribute("value", attr));
|
||||
} else {
|
||||
op = createMlirOperation(
|
||||
"torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
||||
toMlirNamedAttribute("value", attr));
|
||||
}
|
||||
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.device", loc,
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See LICENSE.pytorch for license information.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ones_i32 = torch.ones(1, dtype=torch.int32)
|
||||
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
|
||||
self.arange = torch.nn.Parameter(torch.arange(3.0))
|
||||
|
||||
# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
|
||||
# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>
|
||||
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8>
|
||||
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
|
||||
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
|
||||
# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8>
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32>
|
||||
# CHECK: torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32>
|
||||
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8>
|
||||
# CHECK: }
|
||||
test_module = TestModule()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
|
||||
import_options = ImportOptions()
|
||||
import_options.assumeTensorsHaveValueSemantics = True
|
||||
|
||||
class_annotator = ClassAnnotator()
|
||||
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c, class_annotator, import_options)
|
||||
mb.module.operation.print()
|
|
@ -5,7 +5,7 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
|
||||
from utils import create_script_function
|
||||
|
||||
|
@ -162,3 +162,34 @@ graph():
|
|||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number {
|
||||
# CHECK: %[[A:.*]] = torch.tensor.literal
|
||||
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
|
||||
# CHECK: return %[[RET]] : !torch.number
|
||||
import_options = ImportOptions()
|
||||
import_options.assumeTensorsHaveValueSemantics = False
|
||||
mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """
|
||||
graph():
|
||||
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
|
||||
%1 : Scalar = aten::ScalarImplicit(%0)
|
||||
return (%1)
|
||||
""", parse_tensor_constants=True), import_options)
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
||||
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number {
|
||||
# CHECK: %[[A:.*]] = torch.vtensor.literal
|
||||
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
|
||||
# CHECK: return %[[RET]] : !torch.number
|
||||
import_options.assumeTensorsHaveValueSemantics = True
|
||||
mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """
|
||||
graph():
|
||||
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
|
||||
%1 : Scalar = aten::ScalarImplicit(%0)
|
||||
return (%1)
|
||||
""", parse_tensor_constants=True), import_options)
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -10,6 +10,6 @@ from torch._C import CompilationUnit
|
|||
# RUN: %PYTHON %s
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def create_script_function(func_name, ts_ir_str):
|
||||
def create_script_function(func_name, ts_ir_str, **kwargs):
|
||||
cu = CompilationUnit()
|
||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
|
||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs))
|
||||
|
|
Loading…
Reference in New Issue