Jacque PR

cuda_f16
Abhishek-Varma 2022-12-20 08:06:59 +00:00 committed by Prashant Kumar
parent 957335f348
commit ead5c42a53
15 changed files with 211 additions and 16 deletions

View File

@ -26,6 +26,11 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
}
def ConvertTorchToMLProgram : Pass<"convert-torch-to-mlprogram", "ModuleOp"> {
let summary = "Convert Torch ops to MLProgram ops";
let constructor = "mlir::torch::createConvertTorchToMLProgramPass()";
}
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
#define TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTorchToMLProgramPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMLPROGRAM_TORCHTOMLPROGRAM_H

View File

@ -302,6 +302,41 @@ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
let hasVerifier = 1;
}
def Torch_ResourceNonValueTensorLiteralOp : Torch_Op<"tensor.resource.literal", [
AllowsTypeRefinement,
AllowedInModuleInitializer
]> {
let summary = "Create a value of !torch.tensor type from an external literal";
let description = [{
}];
let arguments = (ins
FlatSymbolRefAttr:$sym_name,
ElementsAttr:$value
);
let results = (outs Torch_NonValueTensorType:$result);
let assemblyFormat = [{
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
}];
}
def Torch_ResourceValueTensorLiteralOp : Torch_Op<"vtensor.resource.literal", [
AllowsTypeRefinement,
]> {
let summary = "Create a value of !torch.vtensor type from an external literal";
let description = [{
}];
let arguments = (ins
FlatSymbolRefAttr:$sym_name,
ElementsAttr:$value
);
let results = (outs Torch_ValueTensorType:$result);
let assemblyFormat = [{
`(` $sym_name `=` $value `)` attr-dict `:` qualified(type($result))
}];
}
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator,
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {

View File

@ -1,6 +1,7 @@
add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToMLProgram)
add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo)
@ -15,6 +16,7 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToArith
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRTorchToMLProgram
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO)

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/IR/BuiltinOps.h"
namespace mlir {
namespace torch {

View File

@ -20,7 +20,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,21 @@
add_mlir_conversion_library(TorchMLIRTorchToMLProgram
TorchToMLProgram.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMLProgram
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRFuncDialect
MLIRMLProgramDialect
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchToMLProgram)

View File

@ -0,0 +1,64 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
#include "../PassDetail.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
namespace {
class ConvertTorchToMLProgram
: public ConvertTorchToMLProgramBase<ConvertTorchToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<ml_program::MLProgramDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
auto module = getOperation();
MLIRContext *context = &getContext();
ConversionTarget dummyTarget(*context);
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(dummyTarget, typeConverter);
auto globalBuilder =
OpBuilder::atBlockBegin(&*module.getBodyRegion().begin());
module.walk([&](Torch::ResourceValueTensorLiteralOp op) {
auto type = typeConverter.convertType(op.getType());
globalBuilder.create<ml_program::GlobalOp>(
op.getLoc(), op.getSymNameAttr().getAttr(), type,
/*is_mutable=*/true, // Just to enable generator
/*value=*/op.getValue(),
globalBuilder.getStringAttr("public"));
OpBuilder builder(op);
auto loadConst = builder.create<ml_program::GlobalLoadOp>(
op.getLoc(), type, op.getSymNameAttr());
Value torchTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
op.getLoc(), op.getType(), loadConst);
op.replaceAllUsesWith(torchTensor);
op.erase();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchToMLProgramPass() {
return std::make_unique<ConvertTorchToMLProgram>();
}

View File

@ -240,6 +240,19 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
return success();
}
static LogicalResult
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp(
ResourceNonValueTensorLiteralOp op, PatternRewriter &rewriter) {
Value valueTensor = rewriter.create<ResourceValueTensorLiteralOp>(
op->getLoc(),
op.getType().cast<NonValueTensorType>().getWithValueSemantics(),
op.getSymName(), op.getValue());
Value tensor =
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
rewriter.replaceOp(op, {tensor});
return success();
}
namespace {
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
void runOnOperation() override {
@ -248,10 +261,13 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
patterns.add(
reduceResourceNonValueTensorLiteralOpToResourceValueTensorLiteralOp);
patterns.add<ReduceNonValueSemanticOps>(context);
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<ResourceNonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {

View File

@ -26,6 +26,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToMLProgram/TorchToMLProgram.h"
using namespace mlir;
using namespace mlir::torch;
@ -64,6 +65,7 @@ void mlir::torch::registerTorchConversionPasses() {
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
pm.addPass(createConvertTorchToMLProgramPass());
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
@ -83,6 +84,8 @@ class VerifyLinalgOnTensorsBackendContractPass
target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
opHasLegalTypes);
target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
opHasLegalTypes);
// ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

View File

@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
use_external_references_if_numel_exceeds: Optional[int] = None,
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
@ -349,6 +350,7 @@ def compile(model: torch.nn.Module,
mb = ModuleBuilder()
import_options = ImportOptions()
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
import_options.useExternalReferencesIfNumelExceeds = use_external_references_if_numel_exceeds
try:
original_stderr = sys.stderr
sys.stderr = StringIO()

View File

@ -10,6 +10,8 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
#include "c10/util/Optional.h"
namespace torch_mlir {
// Common import options across importers. We define this as a struct to avoid
// an unstructured proliferation of different kinds of ways to control different
@ -33,6 +35,9 @@ struct ImportOptions {
// In that case, the appropriate shape information is provided via the type
// bound annotations on the function arguments instead.
bool ignoreExistingTensorShapesAndDtypes = false;
// If this is set, then external constant references will be used when
// importing tensors with numel exceeding the given threshold.
c10::optional<unsigned> useExternalReferencesIfNumelExceeds = c10::nullopt;
};
} // namespace torch_mlir

View File

@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) {
.def_readwrite("assumeTensorsHaveValueSemantics",
&ImportOptions::assumeTensorsHaveValueSemantics)
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
&ImportOptions::ignoreExistingTensorShapesAndDtypes)
.def_readwrite("useExternalReferencesIfNumelExceeds",
&ImportOptions::useExternalReferencesIfNumelExceeds);
}

View File

@ -372,30 +372,44 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
assert(ivalue.isTensor() && "expected a tensor!");
at::Tensor tensor = ivalue.toTensor();
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
// Import the bulk tensor representation.
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp;
if (importOptions.assumeTensorsHaveValueSemantics) {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
MlirValue tensorReprValue;
if (importOptions.useExternalReferencesIfNumelExceeds.has_value() &&
tensor.numel() >
importOptions.useExternalReferencesIfNumelExceeds.value()) {
MlirType type = torchMlirTorchNonValueTensorTypeGet(
context, tensor.sizes().size(), tensor.sizes().data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
MlirAttribute symName = mlirFlatSymbolRefAttrGet(
context, toMlirStringRef(
c10::QualifiedName(attributeNameStack).qualifiedName()));
MlirOperation tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.resource.literal", loc, type,
toMlirNamedAttribute("sym_name", symName),
toMlirNamedAttribute("value", denseElements));
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
} else {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirOperation tensorOp;
if (importOptions.assumeTensorsHaveValueSemantics) {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
} else {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
}
tensorReprValue = mlirOperationGetResult(tensorOp, 0);
}
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
// Construct the complete tensor value. This is trivial for most tensors, but
// for quantized tensors (and probably sparse too, TBD) there is more for us
// to do.