mirror of https://github.com/llvm/torch-mlir
Fix Base Lazy Backend Type Conversion (#1412)
* Fix c10::prim::Constant conversion; Added CAPI for passes; Added passes to base lazy backend * Update ivalue_importer to use ImportOptions; Added tests for non-value/value tensor types * Added tests for scalar Constant import; Updated MB::importFunction to use ImportOptions * Test updates * Move back module variable name * Remove RefineTypes from TorchMlirLoweringContext::Build() * Rename pass; Remove passes from base lazy backend * Rename pass to VerifyBackendContractPass * Aligned cmd pass name; Fixed TorchConversion passes registrationpull/1214/head snapshot-20221005.617
parent
eda18e351c
commit
708fa346a6
|
@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.vtensor type with the tensor attribute.
|
||||||
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
|
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.none type.
|
// !torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||||
|
// Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This header declares the registration and creation method for
|
||||||
|
// transformation passes.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_C_TRANSFORMS_H
|
||||||
|
#define TORCHMLIR_C_TRANSFORMS_H
|
||||||
|
|
||||||
|
#include "mlir-c/Support.h"
|
||||||
|
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_C_TRANSFORMS_H
|
|
@ -1,5 +1,7 @@
|
||||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
|
||||||
|
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
|
||||||
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
|
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
|
||||||
|
|
||||||
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
|
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
|
||||||
|
|
|
@ -21,6 +21,8 @@ class ModuleOp;
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
||||||
ArrayRef<std::string> backendLegalOps);
|
ArrayRef<std::string> backendLegalOps);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
|
||||||
|
|
||||||
StringRef getShapeLibrary();
|
StringRef getShapeLibrary();
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
|
@ -116,6 +120,13 @@ StringRef getShapeLibrary();
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
void registerTorchPasses();
|
void registerTorchPasses();
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -329,4 +329,16 @@ def LowerToBackendContract
|
||||||
let dependentDialects = ["func::FuncDialect"];
|
let dependentDialects = ["func::FuncDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def VerifyBackendContract
|
||||||
|
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
|
||||||
|
let summary = "Check that program satisfies backend contract.";
|
||||||
|
let constructor =
|
||||||
|
"mlir::torch::Torch::createVerifyBackendContractPass()";
|
||||||
|
let description = [{
|
||||||
|
This pass performs a set of inspections to check that program satisfies backend
|
||||||
|
contract. In case of check failure it prints out the error message and returns
|
||||||
|
`signalPassFailure()` status.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCHMLIR_TORCH_PASSES
|
#endif // TORCHMLIR_TORCH_PASSES
|
||||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
|
||||||
Registration.cpp
|
Registration.cpp
|
||||||
TorchOps.cpp
|
TorchOps.cpp
|
||||||
TorchTypes.cpp
|
TorchTypes.cpp
|
||||||
|
Transforms.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
||||||
|
@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
TorchMLIRInitAll
|
TorchMLIRInitAll
|
||||||
|
TorchMLIRTorchPasses
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRCAPI)
|
torch_mlir_target_includes(TorchMLIRCAPI)
|
||||||
|
|
|
@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
|
||||||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||||
|
auto attrTensorType =
|
||||||
|
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
|
||||||
|
return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
|
||||||
|
attrTensorType.getShape(),
|
||||||
|
attrTensorType.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.none type.
|
// torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/CAPI/Pass.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
// Must include the declarations as they carry important visibility attributes.
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -242,6 +242,16 @@ public:
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class VerifyBackendContractPass
|
||||||
|
: public VerifyBackendContractBase<VerifyBackendContractPass> {
|
||||||
|
public:
|
||||||
|
void runOnOperation() override {
|
||||||
|
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
|
||||||
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
|
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
|
||||||
backendLegalOps);
|
backendLegalOps);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::torch::Torch::createVerifyBackendContractPass() {
|
||||||
|
return std::make_unique<VerifyBackendContractPass>();
|
||||||
|
}
|
||||||
|
|
|
@ -11,17 +11,8 @@
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Pass registration
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
#define GEN_PASS_REGISTRATION
|
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
|
||||||
} // end namespace
|
|
||||||
|
|
||||||
void mlir::torch::registerTorchPasses() {
|
void mlir::torch::registerTorchPasses() {
|
||||||
::registerPasses();
|
mlir::torch::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-module-to-torch-backend-pipeline",
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
|
|
|
@ -34,13 +34,13 @@ using namespace mlir::tosa;
|
||||||
// Pass registration
|
// Pass registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace reg {
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
||||||
} // end namespace
|
} // end namespace reg
|
||||||
|
|
||||||
void mlir::torch::registerTorchConversionPasses() {
|
void mlir::torch::registerTorchConversionPasses() {
|
||||||
::registerPasses();
|
reg::registerPasses();
|
||||||
mlir::PassPipelineRegistration<>(
|
mlir::PassPipelineRegistration<>(
|
||||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||||
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
||||||
|
|
|
@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.prim.ListConstruct", loc,
|
importBlock, "torch.prim.ListConstruct", loc,
|
||||||
torchMlirTorchListTypeGet(
|
torchMlirTorchListTypeGet(
|
||||||
getMlirTypeFromTorchType(loc, list.elementType())),
|
getMlirTypeFromTorchType(loc, list.elementType(), importOptions)),
|
||||||
elems);
|
elems);
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
|
@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.prim.DictConstruct", loc,
|
importBlock, "torch.prim.DictConstruct", loc,
|
||||||
torchMlirTorchDictTypeGet(
|
torchMlirTorchDictTypeGet(
|
||||||
getMlirTypeFromTorchType(loc, dict.keyType()),
|
getMlirTypeFromTorchType(loc, dict.keyType(), importOptions),
|
||||||
getMlirTypeFromTorchType(loc, dict.valueType())),
|
getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)),
|
||||||
keys, values);
|
keys, values);
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
|
@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||||
|
|
||||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
MlirOperation tensorOp;
|
||||||
importBlock, "torch.tensor.literal", loc,
|
|
||||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||||
toMlirNamedAttribute("value", denseElements));
|
tensorOp = createMlirOperationAtEnd(
|
||||||
|
importBlock, "torch.vtensor.literal", loc,
|
||||||
|
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
|
||||||
|
toMlirNamedAttribute("value", denseElements));
|
||||||
|
} else {
|
||||||
|
tensorOp = createMlirOperationAtEnd(
|
||||||
|
importBlock, "torch.tensor.literal", loc,
|
||||||
|
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||||
|
toMlirNamedAttribute("value", denseElements));
|
||||||
|
}
|
||||||
|
|
||||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||||
|
|
||||||
// Construct the complete tensor value. This is trivial for most tensors, but
|
// Construct the complete tensor value. This is trivial for most tensors, but
|
||||||
|
@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
// compiler stages that are building a statically modeled quantization
|
// compiler stages that are building a statically modeled quantization
|
||||||
// representation will need to convert this to their representation.
|
// representation will need to convert this to their representation.
|
||||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||||
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
MlirType quantizedTensorType;
|
||||||
context, shape.size(), shape.data(),
|
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||||
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
quantizedTensorType = torchMlirTorchValueTensorTypeGet(
|
||||||
|
context, shape.size(), shape.data(),
|
||||||
|
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||||
|
} else {
|
||||||
|
quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
||||||
|
context, shape.size(), shape.data(),
|
||||||
|
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||||
|
}
|
||||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||||
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
|
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
|
||||||
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
|
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
|
||||||
|
@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
"name", mlirStringAttrGet(
|
"name", mlirStringAttrGet(
|
||||||
context, toMlirStringRef(classAttribute.getName()))),
|
context, toMlirStringRef(classAttribute.getName()))),
|
||||||
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||||
loc, classAttribute.getType()))),
|
loc, classAttribute.getType(), importOptions))),
|
||||||
isPrivate);
|
isPrivate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr
|
||||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
|
||||||
|
py::object maybeImportOptions) {
|
||||||
|
ImportOptions importOptions;
|
||||||
|
if (!maybeImportOptions.is_none()) {
|
||||||
|
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
||||||
|
}
|
||||||
MlirBlock block = getBodyBlock();
|
MlirBlock block = getBodyBlock();
|
||||||
MlirOperation terminator = this->terminator;
|
MlirOperation terminator = this->terminator;
|
||||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
|
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
|
||||||
|
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) {
|
||||||
.def(py::init<py::object>(), py::arg("context") = py::none())
|
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||||
.def("import_function", &ModuleBuilder::importFunction)
|
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
|
||||||
|
py::arg("importOptions") = py::none())
|
||||||
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||||
py::arg("classAnnotator") = py::none(),
|
py::arg("classAnnotator") = py::none(),
|
||||||
py::arg("importOptions") = py::none());
|
py::arg("importOptions") = py::none());
|
||||||
|
|
|
@ -39,7 +39,8 @@ public:
|
||||||
// Just a bit of naming cruft.
|
// Just a bit of naming cruft.
|
||||||
// Returns the same function, making it suitable as a nested decorator.
|
// Returns the same function, making it suitable as a nested decorator.
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr
|
||||||
importFunction(torch::jit::StrongFunctionPtr function);
|
importFunction(torch::jit::StrongFunctionPtr function,
|
||||||
|
py::object maybeImportOptions);
|
||||||
|
|
||||||
// Imports a torch::jit::Module into the current module, using the
|
// Imports a torch::jit::Module into the current module, using the
|
||||||
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
||||||
|
|
|
@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
c10::attr::value)))));
|
c10::attr::value)))));
|
||||||
} else if (output->type()->cast<c10::TensorType>()) {
|
} else if (output->type()->cast<c10::TensorType>()) {
|
||||||
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
||||||
op = createMlirOperation(
|
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||||
"torch.tensor.literal", loc,
|
op = createMlirOperation(
|
||||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
"torch.vtensor.literal", loc,
|
||||||
toMlirNamedAttribute("value", attr));
|
torchMlirTorchValueTensorTypeGetFromAttribute(attr),
|
||||||
|
toMlirNamedAttribute("value", attr));
|
||||||
|
} else {
|
||||||
|
op = createMlirOperation(
|
||||||
|
"torch.tensor.literal", loc,
|
||||||
|
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
||||||
|
toMlirNamedAttribute("value", attr));
|
||||||
|
}
|
||||||
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.device", loc,
|
"torch.constant.device", loc,
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See LICENSE.pytorch for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.ones_i32 = torch.ones(1, dtype=torch.int32)
|
||||||
|
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
|
||||||
|
self.arange = torch.nn.Parameter(torch.arange(3.0))
|
||||||
|
|
||||||
|
# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
|
||||||
|
# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>
|
||||||
|
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8>
|
||||||
|
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
|
||||||
|
# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8>
|
||||||
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32>
|
||||||
|
# CHECK: torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32>
|
||||||
|
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8>
|
||||||
|
# CHECK: }
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
|
||||||
|
import_options = ImportOptions()
|
||||||
|
import_options.assumeTensorsHaveValueSemantics = True
|
||||||
|
|
||||||
|
class_annotator = ClassAnnotator()
|
||||||
|
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c, class_annotator, import_options)
|
||||||
|
mb.module.operation.print()
|
|
@ -5,7 +5,7 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
|
|
||||||
from utils import create_script_function
|
from utils import create_script_function
|
||||||
|
|
||||||
|
@ -162,3 +162,34 @@ graph():
|
||||||
|
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number {
|
||||||
|
# CHECK: %[[A:.*]] = torch.tensor.literal
|
||||||
|
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
|
||||||
|
# CHECK: return %[[RET]] : !torch.number
|
||||||
|
import_options = ImportOptions()
|
||||||
|
import_options.assumeTensorsHaveValueSemantics = False
|
||||||
|
mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """
|
||||||
|
graph():
|
||||||
|
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
|
||||||
|
%1 : Scalar = aten::ScalarImplicit(%0)
|
||||||
|
return (%1)
|
||||||
|
""", parse_tensor_constants=True), import_options)
|
||||||
|
|
||||||
|
mb.module.operation.print()
|
||||||
|
print()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number {
|
||||||
|
# CHECK: %[[A:.*]] = torch.vtensor.literal
|
||||||
|
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
|
||||||
|
# CHECK: return %[[RET]] : !torch.number
|
||||||
|
import_options.assumeTensorsHaveValueSemantics = True
|
||||||
|
mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """
|
||||||
|
graph():
|
||||||
|
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
|
||||||
|
%1 : Scalar = aten::ScalarImplicit(%0)
|
||||||
|
return (%1)
|
||||||
|
""", parse_tensor_constants=True), import_options)
|
||||||
|
|
||||||
|
mb.module.operation.print()
|
||||||
|
print()
|
||||||
|
|
|
@ -10,6 +10,6 @@ from torch._C import CompilationUnit
|
||||||
# RUN: %PYTHON %s
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
# Import TorchScript IR string as ScriptFunction.
|
# Import TorchScript IR string as ScriptFunction.
|
||||||
def create_script_function(func_name, ts_ir_str):
|
def create_script_function(func_name, ts_ir_str, **kwargs):
|
||||||
cu = CompilationUnit()
|
cu = CompilationUnit()
|
||||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
|
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs))
|
||||||
|
|
Loading…
Reference in New Issue