[MLIR][TORCH] Add TorchToTMTensor pass

This pass is added to lower ops, which can not be lowered
via the TorchToLinalg pass, such as `torch.bincount` op.
This pass also uses torch-mlir's TMTensor Dialect to lower the
complex ops.

Also add torch.bincount op lowering with the help of TMTensor dialect

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/649/head snapshot-20220308.312
Vivek Khandelwal 2022-03-02 22:18:15 +05:30
parent b2952b12dd
commit 1a2a9e066f
17 changed files with 403 additions and 32 deletions

View File

@ -1402,3 +1402,53 @@ class HardTanhIntModule(torch.nn.Module):
def HardTanhIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-5, 5, (100, 100)))
class BincountModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
])
def forward(self, x):
return torch.ops.aten.bincount(x)
@register_test_case(module_factory=lambda: BincountModule())
def BincountModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)))
class BincountStaticSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([20], torch.int64, True),
])
def forward(self, x):
return torch.ops.aten.bincount(x)
@register_test_case(module_factory=lambda: BincountStaticSizeModule())
def BincountStaticSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1000, (20,)))
class BincountMinlengthModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
])
def forward(self, x):
return torch.ops.aten.bincount(x, minlength=600)
@register_test_case(module_factory=lambda: BincountMinlengthModule())
def BincountMinlengthModule_basic(module, tu: TestUtils):
module.forward(torch.randint(500, (20,)))

View File

@ -114,4 +114,15 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Convert ATen ops to tmtensor/linalg ops.
This pass is similar to the TorchToLinalg pass; the difference is that this
pass also makes use of TMTensor Dialect, which the former one doesn't.
}];
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}
#endif // TORCHMLIR_CONVERSION_PASSES

View File

@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H

View File

@ -2136,6 +2136,22 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
}
def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$weights,
Torch_IntType:$minlength
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $weights `,` $minlength attr-dict `:` qualified(type($self)) `,` qualified(type($weights)) `,` qualified(type($minlength)) `->` qualified(type($result))";
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -9,11 +9,11 @@
#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_UTILS_H
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
namespace mlir {
namespace torch {
namespace Torch {
@ -22,6 +22,9 @@ int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
torch_upstream::ScalarType getScalarTypeForType(Type type);
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype);
} // namespace Torch
} // namespace torch

View File

@ -2,6 +2,7 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd)
add_subdirectory(TorchToTosa)
add_subdirectory(TorchToTMTensor)
add_subdirectory(Utils)
# TODO: Automate this with add_torch_mlir_conversion_library.
@ -21,6 +22,7 @@ add_mlir_library(TorchMLIRConversionPasses
TorchMLIRTorchToSCF
TorchMLIRTorchToStd
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRConversionUtils
#${torch_mlir_conversion_libs}
)

View File

@ -13,6 +13,7 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
//===----------------------------------------------------------------------===//
// Pass registration

View File

@ -25,7 +25,6 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@ -2254,12 +2253,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
Operation *op,
Type elementType) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
if (elementType.isa<mlir::FloatType>())
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
else if (elementType.isa<mlir::IntegerType>())
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
}
if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
if (isa<AtenMaxOp>(op)) {
if (elementType.isa<mlir::FloatType>())

View File

@ -0,0 +1,22 @@
add_mlir_conversion_library(TorchMLIRTorchToTMTensor
TorchToTMTensor.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalg
MLIRMath
TorchMLIRTorchDialect
TorchMLIRTMTensorDialect
)
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)

View File

@ -0,0 +1,221 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::TMTensor;
// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
// good-enough to create a valid linalg op (such as operands being of
// ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
// compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
// reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
namespace {
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
// non-negative ints.
class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBincountOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
TypeConverter *typeConverter = getTypeConverter();
Value input = adaptor.self();
Value torchTypeInput = op.self();
Value minlength = adaptor.minlength();
Value weights = adaptor.weights();
// TODO: Add a check to verify that the input tensor elements are all
// non-negative.
// Check whether the input is a 1-d tensor of integer type or not.
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
if (inputType.getRank() != 1 ||
!inputType.getElementType().isa<mlir::IntegerType>())
return rewriter.notifyMatchFailure(
op,
"Input tensor has to be a one-dimensional tensor of integer type.");
// Check whether the input tensor element type is i64 or not.
IntegerType inputIntegerType =
inputType.getElementType().cast<IntegerType>();
if (inputIntegerType.getWidth() != 64)
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Integer width not equal to 64 are not supported.");
// TODO: Incorporate the weight argument.
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weights operand is not incorporated.");
// Finding the maximum value in the input tensor.
SmallVector<int64_t> maxTensorSizes;
ValueTensorType maxTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(maxTensorSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value maxTensor =
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
maxTensor = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(maxTensor.getType()),
maxTensor);
// `maxTensor` is a 0-d tensor, extracting its only element and
// storing it in `maxInput`.
Value maxInput = rewriter.create<tensor::ExtractOp>(loc, maxTensor);
// Creating a tm_tensor.scatter op with the following mapping:
// 1.) `input` tensor maps to the indices in scatter op. `input` is
// expanded from 1-d to 2-d, and its element type is set to i32 as required
// for the scatter op.
// 2.) `updates` is a 1-d dummy tensor with the size equivalent to the
// `input`.
// 3.) `bincount` a 1-d tensor maps to the original in scatter op
// with size equal to the max(max(input) + 1, minlength).
SmallVector<int64_t> expandedInputSizes{inputType.getShape()[0], 1};
ValueTensorType expandInputType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedInputSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
loc, expandInputType, torchTypeInput, torchCstOne);
// Converting the input element type to i32.
Value indices = convertTensorToDtype(
rewriter, loc, expandedInputTensor,
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>()
.getElementType();
SmallVector<Value, 1> inputSizeDynamic =
getTensorSizesUntilDim(rewriter, loc, input, 0);
Value updatesTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(inputSizeDynamic), resultElemType);
Value constantZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultElemType));
Value constantOne = rewriter.create<arith::ConstantIntOp>(
loc, 1, resultElemType.getIntOrFloatBitWidth());
// Bincount size = max(max(input) + 1, minlength)
Value maxInputPlusOne =
rewriter.create<arith::AddIOp>(loc, maxInput, constantOne);
Value bincountSize =
rewriter.create<arith::MaxSIOp>(loc, maxInputPlusOne, minlength);
bincountSize = castIntToIndex(rewriter, loc, bincountSize);
Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
resultElemType, constantZero);
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
ValueRange{bincountTensor},
/*unique_indices=*/false);
Region &scatterOpRegion = scatterOp.region();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
{loc, loc});
auto blockArgs = scatterOpBlock.getArguments();
// Creating an add instruction inside the scatter op region to increment the
// frequency counter with one.
OpBuilder regionBuilder(scatterOpRegion);
Value add = regionBuilder.create<arith::AddIOp>(loc,
/*bincount=*/blockArgs[1],
constantOne);
regionBuilder.create<TMTensor::YieldOp>(loc, add);
rewriter.replaceOp(op, scatterOp->getResult(0));
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
class ConvertTorchToTMTensor
: public ConvertTorchToTMTensorBase<ConvertTorchToTMTensor> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
registry.insert<StandardOpsDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<TMTensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
tensor::TensorDialect, arith::ArithmeticDialect,
Torch::TorchDialect, TMTensorDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
target.addIllegalOp<AtenBincountOp>();
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::createConvertTorchToTMTensorPass() {
return std::make_unique<ConvertTorchToTMTensor>();
}

View File

@ -119,30 +119,8 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
return sub;
}
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype) {
int intType = (int)getScalarTypeForType(dtype);
return rewriter.create<ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(intType));
}
// Helper to convert a tensor to a specific scalar type.
static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
Value input, Type dtype) {
BaseTensorType origType = input.getType().cast<BaseTensorType>();
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
// `convertIntVal` contains the corresponding integer for the dtype which is
// used by the aten.to.dtype op.
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value converted = rewriter.create<AtenToDtypeOp>(
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
return converted;
}
// Helper to create a tensor filled with the given `scalar`. `scalar` would be
// converted to the element type of the given `resultType`.
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
Type resultType, Value scalar, Value sizeList) {
BaseTensorType tensorType = resultType.cast<BaseTensorType>();

View File

@ -503,6 +503,8 @@ public:
return visitAtenConstantPadNdOp(constantPadNdOp, operands);
} else if (auto indexTensorOp = dyn_cast<AtenIndexTensorOp>(op)) {
return visitAtenIndexTensorOp(indexTensorOp, operands);
} else if (auto bincountOp = dyn_cast<AtenBincountOp>(op)) {
return visitAtenBincountOp(bincountOp, operands);
}
// Otherwise, this is an unknown operation. Just mark all results as
@ -671,6 +673,9 @@ private:
ChangeResult
visitAtenIndexTensorOp(AtenIndexTensorOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenBincountOp(AtenBincountOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
@ -1995,6 +2000,21 @@ ChangeResult TypeAnalyzer::visitAtenIndexTensorOp(
}
return getLatticeElement(op->getResult(0)).join(knowledge);
}
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
// non-negative ints. It returns a 1-d tensor of size max(max(input), length) of
// the type integer.
ChangeResult TypeAnalyzer::visitAtenBincountOp(
AtenBincountOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
knowledge.hasSizes = true;
knowledge.sizes.resize(1, kUnknownSize);
return getLatticeElement(op->getResult(0)).join(knowledge);
}
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------

View File

@ -46,6 +46,28 @@ ScalarType getScalarTypeForType(Type type) {
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype) {
int intType = (int)getScalarTypeForType(dtype);
return rewriter.create<ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(intType));
}
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype) {
BaseTensorType origType = input.getType().cast<BaseTensorType>();
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
// `convertIntVal` contains the corresponding integer for the dtype which is
// used by the aten.to.dtype op.
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value converted = rewriter.create<AtenToDtypeOp>(
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
return converted;
}
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -23,6 +23,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
TorchMLIRTorchDialect
TorchMLIRTorchPasses
TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchToStd
TorchMLIRTorchToSCF
MLIRMemRefTransforms

View File

@ -19,6 +19,7 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
@ -58,6 +59,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());

View File

@ -16,6 +16,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
@ -24,6 +26,8 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
using namespace TMTensor;
namespace {
class VerifyLinalgOnTensorsBackendContractPass
@ -71,6 +75,7 @@ class VerifyLinalgOnTensorsBackendContractPass
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
// ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

View File

@ -565,6 +565,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit ("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")