mirror of https://github.com/llvm/torch-mlir
Jacque PR
parent
957335f348
commit
ead5c42a53
|
@ -26,6 +26,11 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToMLProgram : Pass<"convert-torch-to-mlprogram", "ModuleOp"> {
|
||||
let summary = "Convert Torch ops to MLProgram ops";
|
||||
let constructor = "mlir::torch::createConvertTorchToMLProgramPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||
let description = [{
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTorchToMLProgramPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
|
@ -302,6 +302,41 @@ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_ResourceNonValueTensorLiteralOp : Torch_Op<"tensor.resource.literal", [
|
||||
AllowsTypeRefinement,
|
||||
AllowedInModuleInitializer
|
||||
]> {
|
||||
let summary = "Create a value of !torch.tensor type from an external literal";
|
||||
let description = [{
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$sym_name,
|
||||
ElementsAttr:$value
|
||||
);
|
||||
let results = (outs Torch_NonValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_ResourceValueTensorLiteralOp : Torch_Op<"vtensor.resource.literal", [
|
||||
AllowsTypeRefinement,
|
||||
]> {
|
||||
let summary = "Create a value of !torch.vtensor type from an external literal";
|
||||
let description = [{
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$sym_name,
|
||||
ElementsAttr:$value
|
||||
);
|
||||
let results = (outs Torch_ValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
||||
Terminator,
|
||||
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_subdirectory(TorchToLinalg)
|
||||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToArith)
|
||||
add_subdirectory(TorchToMLProgram)
|
||||
add_subdirectory(TorchToTosa)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_subdirectory(TorchToMhlo)
|
||||
|
@ -15,6 +16,7 @@ set(linked_libs TorchMLIRTorchToLinalg
|
|||
TorchMLIRTorchToArith
|
||||
TorchMLIRTorchToTosa
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchToMLProgram
|
||||
TorchMLIRTorchConversionToMLProgram
|
||||
TorchMLIRConversionUtils)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToMLProgram
|
||||
TorchToMLProgram.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMLProgram
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRFuncDialect
|
||||
MLIRMLProgramDialect
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToMLProgram)
|
|
@ -0,0 +1,64 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
||||
namespace {
|
||||
class ConvertTorchToMLProgram
|
||||
: public ConvertTorchToMLProgramBase<ConvertTorchToMLProgram> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<ml_program::MLProgramDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget dummyTarget(*context);
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(dummyTarget, typeConverter);
|
||||
|
||||
auto globalBuilder =
|
||||
OpBuilder::atBlockBegin(&*module.getBodyRegion().begin());
|
||||
module.walk([&](Torch::ResourceValueTensorLiteralOp op) {
|
||||
auto type = typeConverter.convertType(op.getType());
|
||||
globalBuilder.create<ml_program::GlobalOp>(
|
||||
op.getLoc(), op.getSymNameAttr().getAttr(), type,
|
||||
/*is_mutable=*/true, // Just to enable generator
|
||||
/*value=*/op.getValue(),
|
||||
globalBuilder.getStringAttr("public"));
|
||||
OpBuilder builder(op);
|
||||
auto loadConst = builder.create<ml_program::GlobalLoadOp>(
|
||||
op.getLoc(), type, op.getSymNameAttr());
|
||||
Value torchTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
|
||||
op.getLoc(), op.getType(), loadConst);
|
||||
op.replaceAllUsesWith(torchTensor);
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::createConvertTorchToMLProgramPass() {
|
||||
return std::make_unique<ConvertTorchToMLProgram>();
|
||||
}
|
|
@ -240,6 +240,19 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp(
|
||||
ResourceNonValueTensorLiteralOp op, PatternRewriter &rewriter) {
|
||||
Value valueTensor = rewriter.create<ResourceValueTensorLiteralOp>(
|
||||
op->getLoc(),
|
||||
op.getType().cast<NonValueTensorType>().getWithValueSemantics(),
|
||||
op.getSymName(), op.getValue());
|
||||
Value tensor =
|
||||
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
|
||||
rewriter.replaceOp(op, {tensor});
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||
void runOnOperation() override {
|
||||
|
@ -248,10 +261,13 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
|
||||
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
||||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||
patterns.add(
|
||||
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp);
|
||||
patterns.add<ReduceNonValueSemanticOps>(context);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<ResourceNonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#endif
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -64,6 +65,7 @@ void mlir::torch::registerTorchConversionPasses() {
|
|||
|
||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm) {
|
||||
pm.addPass(createConvertTorchToMLProgramPass());
|
||||
// Lower to linalg + guards which is the input to codegen backends.
|
||||
// We do this first as it tends to involve pattern-matching against constants,
|
||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -83,6 +84,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
||||
opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
||||
opHasLegalTypes);
|
||||
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
|
|
@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
|
|||
use_tracing: bool = False,
|
||||
ignore_traced_shapes=False,
|
||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||
use_external_references_if_numel_exceeds: Optional[int] = None,
|
||||
verbose: bool = False):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
||||
|
@ -349,6 +350,7 @@ def compile(model: torch.nn.Module,
|
|||
mb = ModuleBuilder()
|
||||
import_options = ImportOptions()
|
||||
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
|
||||
import_options.useExternalReferencesIfNumelExceeds = use_external_references_if_numel_exceeds
|
||||
try:
|
||||
original_stderr = sys.stderr
|
||||
sys.stderr = StringIO()
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||
|
||||
#include "c10/util/Optional.h"
|
||||
|
||||
namespace torch_mlir {
|
||||
// Common import options across importers. We define this as a struct to avoid
|
||||
// an unstructured proliferation of different kinds of ways to control different
|
||||
|
@ -33,6 +35,9 @@ struct ImportOptions {
|
|||
// In that case, the appropriate shape information is provided via the type
|
||||
// bound annotations on the function arguments instead.
|
||||
bool ignoreExistingTensorShapesAndDtypes = false;
|
||||
// If this is set, then external constant references will be used when
|
||||
// importing tensors with numel exceeding the given threshold.
|
||||
c10::optional<unsigned> useExternalReferencesIfNumelExceeds = c10::nullopt;
|
||||
};
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) {
|
|||
.def_readwrite("assumeTensorsHaveValueSemantics",
|
||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes)
|
||||
.def_readwrite("useExternalReferencesIfNumelExceeds",
|
||||
&ImportOptions::useExternalReferencesIfNumelExceeds);
|
||||
}
|
||||
|
|
|
@ -372,30 +372,44 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
|
||||
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||
assert(ivalue.isTensor() && "expected a tensor!");
|
||||
at::Tensor tensor = ivalue.toTensor();
|
||||
|
||||
// TODO: Can we do better?
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
||||
// Import the bulk tensor representation.
|
||||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
|
||||
MlirOperation tensorOp;
|
||||
|
||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.vtensor.literal", loc,
|
||||
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
|
||||
MlirValue tensorReprValue;
|
||||
if (importOptions.useExternalReferencesIfNumelExceeds.has_value() &&
|
||||
tensor.numel() >
|
||||
importOptions.useExternalReferencesIfNumelExceeds.value()) {
|
||||
MlirType type = torchMlirTorchNonValueTensorTypeGet(
|
||||
context, tensor.sizes().size(), tensor.sizes().data(),
|
||||
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||
MlirAttribute symName = mlirFlatSymbolRefAttrGet(
|
||||
context, toMlirStringRef(
|
||||
c10::QualifiedName(attributeNameStack).qualifiedName()));
|
||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.tensor.resource.literal", loc, type,
|
||||
toMlirNamedAttribute("sym_name", symName),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
} else {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirOperation tensorOp;
|
||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.vtensor.literal", loc,
|
||||
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
} else {
|
||||
tensorOp = createMlirOperationAtEnd(
|
||||
importBlock, "torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
}
|
||||
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
}
|
||||
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
||||
// Construct the complete tensor value. This is trivial for most tensors, but
|
||||
// for quantized tensors (and probably sparse too, TBD) there is more for us
|
||||
// to do.
|
||||
|
|
Loading…
Reference in New Issue