mirror of https://github.com/llvm/torch-mlir
Jacque PR
parent
957335f348
commit
ead5c42a53
|
@ -26,6 +26,11 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
|
||||||
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToMLProgram : Pass<"convert-torch-to-mlprogram", "ModuleOp"> {
|
||||||
|
let summary = "Convert Torch ops to MLProgram ops";
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToMLProgramPass()";
|
||||||
|
}
|
||||||
|
|
||||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
||||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
||||||
|
#define TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertTorchToMLProgramPass();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
|
|
@ -302,6 +302,41 @@ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_ResourceNonValueTensorLiteralOp : Torch_Op<"tensor.resource.literal", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
AllowedInModuleInitializer
|
||||||
|
]> {
|
||||||
|
let summary = "Create a value of !torch.tensor type from an external literal";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$sym_name,
|
||||||
|
ElementsAttr:$value
|
||||||
|
);
|
||||||
|
let results = (outs Torch_NonValueTensorType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_ResourceValueTensorLiteralOp : Torch_Op<"vtensor.resource.literal", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
]> {
|
||||||
|
let summary = "Create a value of !torch.vtensor type from an external literal";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$sym_name,
|
||||||
|
ElementsAttr:$value
|
||||||
|
);
|
||||||
|
let results = (outs Torch_ValueTensorType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
add_subdirectory(TorchToLinalg)
|
add_subdirectory(TorchToLinalg)
|
||||||
add_subdirectory(TorchToSCF)
|
add_subdirectory(TorchToSCF)
|
||||||
add_subdirectory(TorchToArith)
|
add_subdirectory(TorchToArith)
|
||||||
|
add_subdirectory(TorchToMLProgram)
|
||||||
add_subdirectory(TorchToTosa)
|
add_subdirectory(TorchToTosa)
|
||||||
if(TORCH_MLIR_ENABLE_MHLO)
|
if(TORCH_MLIR_ENABLE_MHLO)
|
||||||
add_subdirectory(TorchToMhlo)
|
add_subdirectory(TorchToMhlo)
|
||||||
|
@ -15,6 +16,7 @@ set(linked_libs TorchMLIRTorchToLinalg
|
||||||
TorchMLIRTorchToArith
|
TorchMLIRTorchToArith
|
||||||
TorchMLIRTorchToTosa
|
TorchMLIRTorchToTosa
|
||||||
TorchMLIRTorchToTMTensor
|
TorchMLIRTorchToTMTensor
|
||||||
|
TorchMLIRTorchToMLProgram
|
||||||
TorchMLIRTorchConversionToMLProgram
|
TorchMLIRTorchConversionToMLProgram
|
||||||
TorchMLIRConversionUtils)
|
TorchMLIRConversionUtils)
|
||||||
if(TORCH_MLIR_ENABLE_MHLO)
|
if(TORCH_MLIR_ENABLE_MHLO)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
add_mlir_conversion_library(TorchMLIRTorchToMLProgram
|
||||||
|
TorchToMLProgram.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMLProgram
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRFuncDialect
|
||||||
|
MLIRMLProgramDialect
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRTorchToMLProgram)
|
|
@ -0,0 +1,64 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||||
|
#include "mlir/Dialect/Traits.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchToMLProgram
|
||||||
|
: public ConvertTorchToMLProgramBase<ConvertTorchToMLProgram> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<ml_program::MLProgramDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto module = getOperation();
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget dummyTarget(*context);
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
TorchConversion::setupBackendTypeConversion(dummyTarget, typeConverter);
|
||||||
|
|
||||||
|
auto globalBuilder =
|
||||||
|
OpBuilder::atBlockBegin(&*module.getBodyRegion().begin());
|
||||||
|
module.walk([&](Torch::ResourceValueTensorLiteralOp op) {
|
||||||
|
auto type = typeConverter.convertType(op.getType());
|
||||||
|
globalBuilder.create<ml_program::GlobalOp>(
|
||||||
|
op.getLoc(), op.getSymNameAttr().getAttr(), type,
|
||||||
|
/*is_mutable=*/true, // Just to enable generator
|
||||||
|
/*value=*/op.getValue(),
|
||||||
|
globalBuilder.getStringAttr("public"));
|
||||||
|
OpBuilder builder(op);
|
||||||
|
auto loadConst = builder.create<ml_program::GlobalLoadOp>(
|
||||||
|
op.getLoc(), type, op.getSymNameAttr());
|
||||||
|
Value torchTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
|
||||||
|
op.getLoc(), op.getType(), loadConst);
|
||||||
|
op.replaceAllUsesWith(torchTensor);
|
||||||
|
op.erase();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::torch::createConvertTorchToMLProgramPass() {
|
||||||
|
return std::make_unique<ConvertTorchToMLProgram>();
|
||||||
|
}
|
|
@ -240,6 +240,19 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult
|
||||||
|
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp(
|
||||||
|
ResourceNonValueTensorLiteralOp op, PatternRewriter &rewriter) {
|
||||||
|
Value valueTensor = rewriter.create<ResourceValueTensorLiteralOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
op.getType().cast<NonValueTensorType>().getWithValueSemantics(),
|
||||||
|
op.getSymName(), op.getValue());
|
||||||
|
Value tensor =
|
||||||
|
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
|
||||||
|
rewriter.replaceOp(op, {tensor});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -248,10 +261,13 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
|
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
|
||||||
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
||||||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||||
|
patterns.add(
|
||||||
|
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp);
|
||||||
patterns.add<ReduceNonValueSemanticOps>(context);
|
patterns.add<ReduceNonValueSemanticOps>(context);
|
||||||
|
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||||
|
target.addIllegalOp<ResourceNonValueTensorLiteralOp>();
|
||||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||||
#endif
|
#endif
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -64,6 +65,7 @@ void mlir::torch::registerTorchConversionPasses() {
|
||||||
|
|
||||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm) {
|
OpPassManager &pm) {
|
||||||
|
pm.addPass(createConvertTorchToMLProgramPass());
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
// We do this first as it tends to involve pattern-matching against constants,
|
// We do this first as it tends to involve pattern-matching against constants,
|
||||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
@ -83,6 +84,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
||||||
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
|
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
|
||||||
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
||||||
opHasLegalTypes);
|
opHasLegalTypes);
|
||||||
|
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
|
||||||
|
opHasLegalTypes);
|
||||||
|
|
||||||
// ConstantOp is used for tensors and for scalars.
|
// ConstantOp is used for tensors and for scalars.
|
||||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||||
|
|
|
@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
|
||||||
use_tracing: bool = False,
|
use_tracing: bool = False,
|
||||||
ignore_traced_shapes=False,
|
ignore_traced_shapes=False,
|
||||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||||
|
use_external_references_if_numel_exceeds: Optional[int] = None,
|
||||||
verbose: bool = False):
|
verbose: bool = False):
|
||||||
"""Convert a PyTorch model to MLIR.
|
"""Convert a PyTorch model to MLIR.
|
||||||
|
|
||||||
|
@ -349,6 +350,7 @@ def compile(model: torch.nn.Module,
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
import_options = ImportOptions()
|
import_options = ImportOptions()
|
||||||
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
|
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
|
||||||
|
import_options.useExternalReferencesIfNumelExceeds = use_external_references_if_numel_exceeds
|
||||||
try:
|
try:
|
||||||
original_stderr = sys.stderr
|
original_stderr = sys.stderr
|
||||||
sys.stderr = StringIO()
|
sys.stderr = StringIO()
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||||
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||||
|
|
||||||
|
#include "c10/util/Optional.h"
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
// Common import options across importers. We define this as a struct to avoid
|
// Common import options across importers. We define this as a struct to avoid
|
||||||
// an unstructured proliferation of different kinds of ways to control different
|
// an unstructured proliferation of different kinds of ways to control different
|
||||||
|
@ -33,6 +35,9 @@ struct ImportOptions {
|
||||||
// In that case, the appropriate shape information is provided via the type
|
// In that case, the appropriate shape information is provided via the type
|
||||||
// bound annotations on the function arguments instead.
|
// bound annotations on the function arguments instead.
|
||||||
bool ignoreExistingTensorShapesAndDtypes = false;
|
bool ignoreExistingTensorShapesAndDtypes = false;
|
||||||
|
// If this is set, then external constant references will be used when
|
||||||
|
// importing tensors with numel exceeding the given threshold.
|
||||||
|
c10::optional<unsigned> useExternalReferencesIfNumelExceeds = c10::nullopt;
|
||||||
};
|
};
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
|
|
@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) {
|
||||||
.def_readwrite("assumeTensorsHaveValueSemantics",
|
.def_readwrite("assumeTensorsHaveValueSemantics",
|
||||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||||
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
||||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
&ImportOptions::ignoreExistingTensorShapesAndDtypes)
|
||||||
|
.def_readwrite("useExternalReferencesIfNumelExceeds",
|
||||||
|
&ImportOptions::useExternalReferencesIfNumelExceeds);
|
||||||
}
|
}
|
||||||
|
|
|
@ -372,16 +372,31 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
|
|
||||||
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
assert(ivalue.isTensor() && "expected a tensor!");
|
assert(ivalue.isTensor() && "expected a tensor!");
|
||||||
|
at::Tensor tensor = ivalue.toTensor();
|
||||||
|
|
||||||
// TODO: Can we do better?
|
// TODO: Can we do better?
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
// Import the bulk tensor representation.
|
// Import the bulk tensor representation.
|
||||||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
|
||||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||||
|
|
||||||
|
MlirValue tensorReprValue;
|
||||||
|
if (importOptions.useExternalReferencesIfNumelExceeds.has_value() &&
|
||||||
|
tensor.numel() >
|
||||||
|
importOptions.useExternalReferencesIfNumelExceeds.value()) {
|
||||||
|
MlirType type = torchMlirTorchNonValueTensorTypeGet(
|
||||||
|
context, tensor.sizes().size(), tensor.sizes().data(),
|
||||||
|
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
|
||||||
|
MlirAttribute symName = mlirFlatSymbolRefAttrGet(
|
||||||
|
context, toMlirStringRef(
|
||||||
|
c10::QualifiedName(attributeNameStack).qualifiedName()));
|
||||||
|
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||||
|
importBlock, "torch.tensor.resource.literal", loc, type,
|
||||||
|
toMlirNamedAttribute("sym_name", symName),
|
||||||
|
toMlirNamedAttribute("value", denseElements));
|
||||||
|
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||||
|
} else {
|
||||||
MlirOperation tensorOp;
|
MlirOperation tensorOp;
|
||||||
|
|
||||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||||
tensorOp = createMlirOperationAtEnd(
|
tensorOp = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.vtensor.literal", loc,
|
importBlock, "torch.vtensor.literal", loc,
|
||||||
|
@ -393,9 +408,8 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
|
||||||
toMlirNamedAttribute("value", denseElements));
|
toMlirNamedAttribute("value", denseElements));
|
||||||
}
|
}
|
||||||
|
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
}
|
||||||
|
|
||||||
// Construct the complete tensor value. This is trivial for most tensors, but
|
// Construct the complete tensor value. This is trivial for most tensors, but
|
||||||
// for quantized tensors (and probably sparse too, TBD) there is more for us
|
// for quantized tensors (and probably sparse too, TBD) there is more for us
|
||||||
// to do.
|
// to do.
|
||||||
|
|
Loading…
Reference in New Issue