Jacque PR

cuda_f16
Abhishek-Varma 2022-12-20 08:06:59 +00:00 committed by Prashant Kumar
parent 957335f348
commit ead5c42a53
15 changed files with 211 additions and 16 deletions

View File

@ -26,6 +26,11 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToSCFPass()"; let constructor = "mlir::torch::createConvertTorchToSCFPass()";
} }
def ConvertTorchToMLProgram : Pass<"convert-torch-to-mlprogram", "ModuleOp"> {
let summary = "Convert Torch ops to MLProgram ops";
let constructor = "mlir::torch::createConvertTorchToMLProgramPass()";
}
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops"; let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{ let description = [{

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
#define TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTorchToMLProgramPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H

View File

@ -302,6 +302,41 @@ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
let hasVerifier = 1; let hasVerifier = 1;
} }
def Torch_ResourceNonValueTensorLiteralOp : Torch_Op<"tensor.resource.literal", [
AllowsTypeRefinement,
AllowedInModuleInitializer
]> {
let summary = "Create a value of !torch.tensor type from an external literal";
let description = [{
}];
let arguments = (ins
FlatSymbolRefAttr:$sym_name,
ElementsAttr:$value
);
let results = (outs Torch_NonValueTensorType:$result);
let assemblyFormat = [{
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
}];
}
def Torch_ResourceValueTensorLiteralOp : Torch_Op<"vtensor.resource.literal", [
AllowsTypeRefinement,
]> {
let summary = "Create a value of !torch.vtensor type from an external literal";
let description = [{
}];
let arguments = (ins
FlatSymbolRefAttr:$sym_name,
ElementsAttr:$value
);
let results = (outs Torch_ValueTensorType:$result);
let assemblyFormat = [{
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
}];
}
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [ def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator, Terminator,
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> { HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {

View File

@ -1,6 +1,7 @@
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith) add_subdirectory(TorchToArith)
add_subdirectory(TorchToMLProgram)
add_subdirectory(TorchToTosa) add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo) add_subdirectory(TorchToMhlo)
@ -15,6 +16,7 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToArith TorchMLIRTorchToArith
TorchMLIRTorchToTosa TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchToMLProgram
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils) TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/IR/BuiltinOps.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {

View File

@ -20,7 +20,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,21 @@
add_mlir_conversion_library(TorchMLIRTorchToMLProgram
TorchToMLProgram.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMLProgram
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRFuncDialect
MLIRMLProgramDialect
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchToMLProgram)

View File

@ -0,0 +1,64 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
#include "../PassDetail.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
namespace {
class ConvertTorchToMLProgram
: public ConvertTorchToMLProgramBase<ConvertTorchToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<ml_program::MLProgramDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
auto module = getOperation();
MLIRContext *context = &getContext();
ConversionTarget dummyTarget(*context);
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(dummyTarget, typeConverter);
auto globalBuilder =
OpBuilder::atBlockBegin(&*module.getBodyRegion().begin());
module.walk([&](Torch::ResourceValueTensorLiteralOp op) {
auto type = typeConverter.convertType(op.getType());
globalBuilder.create<ml_program::GlobalOp>(
op.getLoc(), op.getSymNameAttr().getAttr(), type,
/*is_mutable=*/true, // Just to enable generator
/*value=*/op.getValue(),
globalBuilder.getStringAttr("public"));
OpBuilder builder(op);
auto loadConst = builder.create<ml_program::GlobalLoadOp>(
op.getLoc(), type, op.getSymNameAttr());
Value torchTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
op.getLoc(), op.getType(), loadConst);
op.replaceAllUsesWith(torchTensor);
op.erase();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchToMLProgramPass() {
return std::make_unique<ConvertTorchToMLProgram>();
}

View File

@ -240,6 +240,19 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
return success(); return success();
} }
static LogicalResult
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp(
ResourceNonValueTensorLiteralOp op, PatternRewriter &rewriter) {
Value valueTensor = rewriter.create<ResourceValueTensorLiteralOp>(
op->getLoc(),
op.getType().cast<NonValueTensorType>().getWithValueSemantics(),
op.getSymName(), op.getValue());
Value tensor =
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
rewriter.replaceOp(op, {tensor});
return success();
}
namespace { namespace {
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> { class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
void runOnOperation() override { void runOnOperation() override {
@ -248,10 +261,13 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context); patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context); patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp); patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
patterns.add(
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp);
patterns.add<ReduceNonValueSemanticOps>(context); patterns.add<ReduceNonValueSemanticOps>(context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>(); target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<ResourceNonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>(); target.addIllegalOp<AtenBernoulli_FloatOp>();
target.markUnknownOpDynamicallyLegal([](Operation *op) { target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) { if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {

View File

@ -26,6 +26,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif #endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -64,6 +65,7 @@ void mlir::torch::registerTorchConversionPasses() {
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) { OpPassManager &pm) {
pm.addPass(createConvertTorchToMLProgramPass());
// Lower to linalg + guards which is the input to codegen backends. // Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants, // We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model) // (e.g. dimensions which must be constant in a ranked programming model)

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -83,6 +84,8 @@ class VerifyLinalgOnTensorsBackendContractPass
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>( target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
opHasLegalTypes); opHasLegalTypes);
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
opHasLegalTypes);
// ConstantOp is used for tensors and for scalars. // ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes); target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

View File

@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False, use_tracing: bool = False,
ignore_traced_shapes=False, ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None, backend_legal_ops: Optional[Sequence[str]] = None,
use_external_references_if_numel_exceeds: Optional[int] = None,
verbose: bool = False): verbose: bool = False):
"""Convert a PyTorch model to MLIR. """Convert a PyTorch model to MLIR.
@ -349,6 +350,7 @@ def compile(model: torch.nn.Module,
mb = ModuleBuilder() mb = ModuleBuilder()
import_options = ImportOptions() import_options = ImportOptions()
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
import_options.useExternalReferencesIfNumelExceeds = use_external_references_if_numel_exceeds
try: try:
original_stderr = sys.stderr original_stderr = sys.stderr
sys.stderr = StringIO() sys.stderr = StringIO()

View File

@ -10,6 +10,8 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H #ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H #define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
#include "c10/util/Optional.h"
namespace torch_mlir { namespace torch_mlir {
// Common import options across importers. We define this as a struct to avoid // Common import options across importers. We define this as a struct to avoid
// an unstructured proliferation of different kinds of ways to control different // an unstructured proliferation of different kinds of ways to control different
@ -33,6 +35,9 @@ struct ImportOptions {
// In that case, the appropriate shape information is provided via the type // In that case, the appropriate shape information is provided via the type
// bound annotations on the function arguments instead. // bound annotations on the function arguments instead.
bool ignoreExistingTensorShapesAndDtypes = false; bool ignoreExistingTensorShapesAndDtypes = false;
// If this is set, then external constant references will be used when
// importing tensors with numel exceeding the given threshold.
c10::optional<unsigned> useExternalReferencesIfNumelExceeds = c10::nullopt;
}; };
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) {
.def_readwrite("assumeTensorsHaveValueSemantics", .def_readwrite("assumeTensorsHaveValueSemantics",
&ImportOptions::assumeTensorsHaveValueSemantics) &ImportOptions::assumeTensorsHaveValueSemantics)
.def_readwrite("ignoreExistingTensorShapesAndDtypes", .def_readwrite("ignoreExistingTensorShapesAndDtypes",
&ImportOptions::ignoreExistingTensorShapesAndDtypes); &ImportOptions::ignoreExistingTensorShapesAndDtypes)
.def_readwrite("useExternalReferencesIfNumelExceeds",
&ImportOptions::useExternalReferencesIfNumelExceeds);
} }

View File

@ -372,30 +372,44 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirValue IValueImporter::importTensor(c10::IValue ivalue) { MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
assert(ivalue.isTensor() && "expected a tensor!"); assert(ivalue.isTensor() && "expected a tensor!");
at::Tensor tensor = ivalue.toTensor();
// TODO: Can we do better? // TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
// Import the bulk tensor representation. // Import the bulk tensor representation.
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp; MlirValue tensorReprValue;
if (importOptions.useExternalReferencesIfNumelExceeds.has_value() &&
if (importOptions.assumeTensorsHaveValueSemantics) { tensor.numel() >
tensorOp = createMlirOperationAtEnd( importOptions.useExternalReferencesIfNumelExceeds.value()) {
importBlock, "torch.vtensor.literal", loc, MlirType type = torchMlirTorchNonValueTensorTypeGet(
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements), context, tensor.sizes().size(), tensor.sizes().data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
MlirAttribute symName = mlirFlatSymbolRefAttrGet(
context, toMlirStringRef(
c10::QualifiedName(attributeNameStack).qualifiedName()));
MlirOperation tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.resource.literal", loc, type,
toMlirNamedAttribute("sym_name", symName),
toMlirNamedAttribute("value", denseElements)); toMlirNamedAttribute("value", denseElements));
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
} else { } else {
tensorOp = createMlirOperationAtEnd( MlirOperation tensorOp;
importBlock, "torch.tensor.literal", loc, if (importOptions.assumeTensorsHaveValueSemantics) {
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements), tensorOp = createMlirOperationAtEnd(
toMlirNamedAttribute("value", denseElements)); importBlock, "torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
} else {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
}
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
} }
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
// Construct the complete tensor value. This is trivial for most tensors, but // Construct the complete tensor value. This is trivial for most tensors, but
// for quantized tensors (and probably sparse too, TBD) there is more for us // for quantized tensors (and probably sparse too, TBD) there is more for us
// to do. // to do.