mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchToTMTensor pass
This pass is added to lower ops, which can not be lowered via the TorchToLinalg pass, such as `torch.bincount` op. This pass also uses torch-mlir's TMTensor Dialect to lower the complex ops. Also add torch.bincount op lowering with the help of TMTensor dialect Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/649/head snapshot-20220308.312
parent
b2952b12dd
commit
1a2a9e066f
|
@ -1402,3 +1402,53 @@ class HardTanhIntModule(torch.nn.Module):
|
|||
def HardTanhIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-5, 5, (100, 100)))
|
||||
|
||||
|
||||
class BincountModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: BincountModule())
|
||||
def BincountModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)))
|
||||
|
||||
|
||||
class BincountStaticSizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([20], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: BincountStaticSizeModule())
|
||||
def BincountStaticSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (20,)))
|
||||
|
||||
|
||||
class BincountMinlengthModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x, minlength=600)
|
||||
|
||||
@register_test_case(module_factory=lambda: BincountMinlengthModule())
|
||||
def BincountMinlengthModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(500, (20,)))
|
||||
|
|
|
@ -114,4 +114,15 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
|
||||
let description = [{
|
||||
Convert ATen ops to tmtensor/linalg ops.
|
||||
|
||||
This pass is similar to the TorchToLinalg pass; the difference is that this
|
||||
pass also makes use of TMTensor Dialect, which the former one doesn't.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_PASSES
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
|
@ -2136,6 +2136,22 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
|
|||
let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorType:$weights,
|
||||
Torch_IntType:$minlength
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $weights `,` $minlength attr-dict `:` qualified(type($self)) `,` qualified(type($weights)) `,` qualified(type($minlength)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -9,11 +9,11 @@
|
|||
#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_UTILS_H
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
@ -22,6 +22,9 @@ int64_t toPositiveDim(int64_t dim, int64_t inputRank);
|
|||
bool isValidDim(int64_t dim, int64_t inputRank);
|
||||
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||
// Helper to convert a tensor to a specific scalar type.
|
||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||
Type dtype);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
|
|
|
@ -2,6 +2,7 @@ add_subdirectory(TorchToLinalg)
|
|||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToStd)
|
||||
add_subdirectory(TorchToTosa)
|
||||
add_subdirectory(TorchToTMTensor)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
||||
|
@ -21,6 +22,7 @@ add_mlir_library(TorchMLIRConversionPasses
|
|||
TorchMLIRTorchToSCF
|
||||
TorchMLIRTorchToStd
|
||||
TorchMLIRTorchToTosa
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRConversionUtils
|
||||
#${torch_mlir_conversion_libs}
|
||||
)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
@ -2254,12 +2253,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
||||
Operation *op,
|
||||
Type elementType) {
|
||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
|
||||
if (elementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
else if (elementType.isa<mlir::IntegerType>())
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
}
|
||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
|
||||
if (isa<AtenMaxOp>(op)) {
|
||||
if (elementType.isa<mlir::FloatType>())
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToTMTensor
|
||||
TorchToTMTensor.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRLinalg
|
||||
MLIRMath
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTMTensorDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)
|
|
@ -0,0 +1,221 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
using namespace mlir::torch::TMTensor;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Patterns (as this grows, it should be organized into multiple files)
|
||||
// -----------------------------------------------------------------------------
|
||||
// This is going to eventually be O(#aten ops), which is in the 100s.
|
||||
//
|
||||
// Most of these patterns consist of:
|
||||
// 1. Checking that the operand/result types and other static properties are
|
||||
// good-enough to create a valid linalg op (such as operands being of
|
||||
// ranks/dtypes acceptable to the linalg op).
|
||||
// 2. Creating dynamic error guards, usually checking a predicate on the
|
||||
// compatibility of operand shapes.
|
||||
// 3. Creating init tensors for the computation op. Usually this involves
|
||||
// reifying IR for a shape transfer function based on the operand shapes.
|
||||
// 4. Creating a named linalg op to replace the original op.
|
||||
//
|
||||
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
|
||||
// that these patterns become mostly mechanical associations of
|
||||
// "aten.foo -> linalg.foo".
|
||||
|
||||
namespace {
|
||||
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
||||
// non-negative ints.
|
||||
class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenBincountOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
Value input = adaptor.self();
|
||||
Value torchTypeInput = op.self();
|
||||
Value minlength = adaptor.minlength();
|
||||
Value weights = adaptor.weights();
|
||||
|
||||
// TODO: Add a check to verify that the input tensor elements are all
|
||||
// non-negative.
|
||||
// Check whether the input is a 1-d tensor of integer type or not.
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
if (inputType.getRank() != 1 ||
|
||||
!inputType.getElementType().isa<mlir::IntegerType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Input tensor has to be a one-dimensional tensor of integer type.");
|
||||
|
||||
// Check whether the input tensor element type is i64 or not.
|
||||
IntegerType inputIntegerType =
|
||||
inputType.getElementType().cast<IntegerType>();
|
||||
if (inputIntegerType.getWidth() != 64)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Unimplemented: Integer width not equal to 64 are not supported.");
|
||||
|
||||
// TODO: Incorporate the weight argument.
|
||||
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented, the weights operand is not incorporated.");
|
||||
|
||||
// Finding the maximum value in the input tensor.
|
||||
SmallVector<int64_t> maxTensorSizes;
|
||||
ValueTensorType maxTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(maxTensorSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value maxTensor =
|
||||
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
||||
maxTensor = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(maxTensor.getType()),
|
||||
maxTensor);
|
||||
|
||||
// `maxTensor` is a 0-d tensor, extracting its only element and
|
||||
// storing it in `maxInput`.
|
||||
Value maxInput = rewriter.create<tensor::ExtractOp>(loc, maxTensor);
|
||||
|
||||
// Creating a tm_tensor.scatter op with the following mapping:
|
||||
// 1.) `input` tensor maps to the indices in scatter op. `input` is
|
||||
// expanded from 1-d to 2-d, and its element type is set to i32 as required
|
||||
// for the scatter op.
|
||||
// 2.) `updates` is a 1-d dummy tensor with the size equivalent to the
|
||||
// `input`.
|
||||
// 3.) `bincount` a 1-d tensor maps to the original in scatter op
|
||||
// with size equal to the max(max(input) + 1, minlength).
|
||||
SmallVector<int64_t> expandedInputSizes{inputType.getShape()[0], 1};
|
||||
ValueTensorType expandInputType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedInputSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
loc, expandInputType, torchTypeInput, torchCstOne);
|
||||
|
||||
// Converting the input element type to i32.
|
||||
Value indices = convertTensorToDtype(
|
||||
rewriter, loc, expandedInputTensor,
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
SmallVector<Value, 1> inputSizeDynamic =
|
||||
getTensorSizesUntilDim(rewriter, loc, input, 0);
|
||||
Value updatesTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, getAsOpFoldResult(inputSizeDynamic), resultElemType);
|
||||
|
||||
Value constantZero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(resultElemType));
|
||||
Value constantOne = rewriter.create<arith::ConstantIntOp>(
|
||||
loc, 1, resultElemType.getIntOrFloatBitWidth());
|
||||
|
||||
// Bincount size = max(max(input) + 1, minlength)
|
||||
Value maxInputPlusOne =
|
||||
rewriter.create<arith::AddIOp>(loc, maxInput, constantOne);
|
||||
Value bincountSize =
|
||||
rewriter.create<arith::MaxSIOp>(loc, maxInputPlusOne, minlength);
|
||||
bincountSize = castIntToIndex(rewriter, loc, bincountSize);
|
||||
Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
|
||||
resultElemType, constantZero);
|
||||
|
||||
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
|
||||
loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
|
||||
ValueRange{bincountTensor},
|
||||
/*unique_indices=*/false);
|
||||
|
||||
Region &scatterOpRegion = scatterOp.region();
|
||||
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
|
||||
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
|
||||
{loc, loc});
|
||||
auto blockArgs = scatterOpBlock.getArguments();
|
||||
|
||||
// Creating an add instruction inside the scatter op region to increment the
|
||||
// frequency counter with one.
|
||||
OpBuilder regionBuilder(scatterOpRegion);
|
||||
Value add = regionBuilder.create<arith::AddIOp>(loc,
|
||||
/*bincount=*/blockArgs[1],
|
||||
constantOne);
|
||||
regionBuilder.create<TMTensor::YieldOp>(loc, add);
|
||||
rewriter.replaceOp(op, scatterOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
class ConvertTorchToTMTensor
|
||||
: public ConvertTorchToTMTensorBase<ConvertTorchToTMTensor> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
registry.insert<TMTensorDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect, arith::ArithmeticDialect,
|
||||
Torch::TorchDialect, TMTensorDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<AtenBincountOp>();
|
||||
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::torch::createConvertTorchToTMTensorPass() {
|
||||
return std::make_unique<ConvertTorchToTMTensor>();
|
||||
}
|
|
@ -119,30 +119,8 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
|||
return sub;
|
||||
}
|
||||
|
||||
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type dtype) {
|
||||
int intType = (int)getScalarTypeForType(dtype);
|
||||
return rewriter.create<ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(intType));
|
||||
}
|
||||
|
||||
// Helper to convert a tensor to a specific scalar type.
|
||||
static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
||||
Value input, Type dtype) {
|
||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||
// `convertIntVal` contains the corresponding integer for the dtype which is
|
||||
// used by the aten.to.dtype op.
|
||||
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value converted = rewriter.create<AtenToDtypeOp>(
|
||||
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
||||
return converted;
|
||||
}
|
||||
|
||||
// Helper to create a tensor filled with the given `scalar`. `scalar` would be
|
||||
// converted to the element type of the given `resultType`.
|
||||
// Helper to create a tensor filled with the given scalar. Scalar would be
|
||||
// converted the to the element type of the given tensor type.
|
||||
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
||||
Type resultType, Value scalar, Value sizeList) {
|
||||
BaseTensorType tensorType = resultType.cast<BaseTensorType>();
|
||||
|
|
|
@ -503,6 +503,8 @@ public:
|
|||
return visitAtenConstantPadNdOp(constantPadNdOp, operands);
|
||||
} else if (auto indexTensorOp = dyn_cast<AtenIndexTensorOp>(op)) {
|
||||
return visitAtenIndexTensorOp(indexTensorOp, operands);
|
||||
} else if (auto bincountOp = dyn_cast<AtenBincountOp>(op)) {
|
||||
return visitAtenBincountOp(bincountOp, operands);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -671,6 +673,9 @@ private:
|
|||
ChangeResult
|
||||
visitAtenIndexTensorOp(AtenIndexTensorOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenBincountOp(AtenBincountOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1995,6 +2000,21 @@ ChangeResult TypeAnalyzer::visitAtenIndexTensorOp(
|
|||
}
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
||||
// non-negative ints. It returns a 1-d tensor of size max(max(input), length) of
|
||||
// the type integer.
|
||||
ChangeResult TypeAnalyzer::visitAtenBincountOp(
|
||||
AtenBincountOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(1, kUnknownSize);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Transforms.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -46,6 +46,28 @@ ScalarType getScalarTypeForType(Type type) {
|
|||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||
}
|
||||
|
||||
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type dtype) {
|
||||
int intType = (int)getScalarTypeForType(dtype);
|
||||
return rewriter.create<ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(intType));
|
||||
}
|
||||
|
||||
// Helper to convert a tensor to a specific scalar type.
|
||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||
Type dtype) {
|
||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||
// `convertIntVal` contains the corresponding integer for the dtype which is
|
||||
// used by the aten.to.dtype op.
|
||||
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value converted = rewriter.create<AtenToDtypeOp>(
|
||||
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
||||
return converted;
|
||||
}
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,6 +23,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
TorchMLIRTorchDialect
|
||||
TorchMLIRTorchPasses
|
||||
TorchMLIRTorchToLinalg
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchToStd
|
||||
TorchMLIRTorchToSCF
|
||||
MLIRMemRefTransforms
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -58,6 +59,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
// We do this first as it tends to involve pattern-matching against constants,
|
||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||
// and those constants get somewhat obscured by TorchToStd.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
||||
|
@ -24,6 +26,8 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
using namespace TMTensor;
|
||||
|
||||
|
||||
namespace {
|
||||
class VerifyLinalgOnTensorsBackendContractPass
|
||||
|
@ -71,6 +75,7 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
|
||||
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
|
|
@ -565,6 +565,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit ("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue