mirror of https://github.com/llvm/torch-mlir
E2e support for aten.softmax.int and aten.embedding

- Added a DecomposeComplexOps pass to decompose complex torchOps.
- Refactored `visitAtenArgmaxOp` and `visitAtenAnyDimOp` to `visitReductionAlongDimIntOp`.
- Moved some helper functions into torch-mlir/Dialect/Torch/Utils/Utils.h to be shared by multiple files.
- Added support for f64 tensor as argument and return types.

parent 0902438882
commit a459e09ab7
@@ -259,3 +259,125 @@ class GatherModule(torch.nn.Module):
@register_test_case(module_factory=lambda: GatherModule())
def GatherModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))


class AddSizeIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        # This is a workaround for not supporting scalar arguments.
        # TODO: pass in dim as an argument to the forward method when scalar
        # arguments are supported.
        return tensor.add(tensor, alpha=tensor.size(1))


@register_test_case(module_factory=lambda: AddSizeIntModule())
def AddSizeIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3))


class AddSizeIntNegDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        # This is a workaround for not supporting scalar arguments.
        # TODO: pass in dim as an argument to the forward method when scalar
        # arguments are supported.
        return tensor.add(tensor, alpha=tensor.size(-2))


@register_test_case(module_factory=lambda: AddSizeIntNegDimModule())
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3))


class EmbeddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.embed = torch.nn.Embedding(num_embeddings=100,
                                        embedding_dim=50,
                                        padding_idx=4)

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, indices):
        return self.embed.forward(indices)


@register_test_case(module_factory=lambda: EmbeddingModule())
def EmbeddingModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (3, 3)))


class SoftmaxIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.softmax = torch.nn.Softmax(2)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        return self.softmax.forward(tensor)


@register_test_case(module_factory=lambda: SoftmaxIntModule())
def SoftmaxIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))


class SoftmaxIntNegDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.softmax = torch.nn.Softmax(-2)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        return self.softmax.forward(tensor)


@register_test_case(module_factory=lambda: SoftmaxIntNegDimModule())
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))


class SoftmaxIntArgTypeF64Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.softmax = torch.nn.Softmax(2)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, tensor):
        return self.softmax.forward(tensor)


@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module())
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4).double())
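The AddSizeInt tests above smuggle a scalar into the graph through `tensor.size(...)` because scalar arguments are not supported yet. A quick eager-mode sketch of what the 3x3 test case computes (illustration only, not part of the commit):

import torch

x = torch.randn(3, 3)
alpha = x.size(1)  # == 3; the "scalar" obtained via aten.size.int
# torch.Tensor.add(other, alpha=n) computes self + n * other
assert torch.allclose(x.add(x, alpha=alpha), x + alpha * x)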
@@ -937,6 +937,22 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
  let assemblyFormat = "$self `,` $kernel_size `,` $stride `,` $padding `,` $dilation `,` $ceil_mode attr-dict `:` type($self) `,` type($kernel_size) `,` type($stride) `,` type($padding) `,` type($dilation) `,` type($ceil_mode) `->` type($result)";
}

def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim,
    TorchOptionalIntType:$dtype
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
}

def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -1591,6 +1607,7 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
    Torch_IntType:$result
  );
  let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
  let hasFolder = 1;
}

def Torch_AtenStackOp : Torch_Op<"aten.stack", [
@@ -819,6 +819,7 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `to` type($result)
  }];
  let hasCanonicalizer = 1;
}

def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
@@ -54,6 +54,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();

std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();

std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass();

} // namespace Torch

/// Registers all Torch transformation passes.
@@ -215,4 +215,20 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
  }];
}

def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
  let summary = "Decompose complicated torch operations";
  let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
  let description = [{
    Decompose torch operations that are losslessly represented as combinations of
    other operations, modulo appropriate compiler fusion. Note that this pass
    is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
    systematic reductions of a large number of ops at once, guided mostly by
    traits.

    An example of the transformations done in this pass is:
    - convert aten.softmax.int: softmax(x, dim)
        => tmp = exp(x); tmp / sum(tmp, dim, keepdim=True)
  }];
}

#endif // TORCHMLIR_TORCH_PASSES
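The decomposition sketched in the pass description can be sanity-checked against PyTorch's own softmax. A minimal Python sketch (illustration only, not part of the commit; the decomposition omits the max-subtraction PyTorch applies internally for numerical stability, so the equivalence is mathematical rather than bitwise):

import torch

x = torch.randn(3, 2, 4)
dim = 2

tmp = torch.exp(x)                                    # exp(x)
decomposed = tmp / torch.sum(tmp, dim, keepdim=True)  # exp(x) / sum(exp(x), dim)
assert torch.allclose(decomposed, torch.softmax(x, dim))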
@@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_UTILS_H

#include "mlir/Support/LLVM.h"

namespace mlir {
namespace torch {
namespace Torch {

int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);

} // namespace Torch
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_DIALECT_TORCH_UTILS_H
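The two helpers declared here implement PyTorch's negative-dimension convention; their definitions are added in Utils.cpp further down. A small Python sketch of the same logic (illustration only), showing how the `size(-2)` used by AddSizeIntNegDimModule resolves:

def to_positive_dim(dim, input_rank):
    # Mirrors Torch::toPositiveDim.
    return dim if dim >= 0 else dim + input_rank

def is_valid_dim(dim, input_rank):
    # Mirrors Torch::isValidDim.
    return 0 <= dim < input_rank

# For a rank-2 (3x3) tensor, dim=-2 normalizes to 0 and is valid; dim=2 is not.
assert to_positive_dim(-2, 2) == 0 and is_valid_dim(0, 2)
assert not is_valid_dim(to_positive_dim(2, 2), 2)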
@@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

@@ -70,6 +71,35 @@ static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
  return success();
}

// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
                                  Value inputRank) {
  assert(dim.getType().isa<IntegerType>() &&
         "dim arg of toPositiveDim must be integer type");
  Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predDimGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
  return dimInt;
}

// Generate IR: assert(dim >= 0 && dim < inputRank)
static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
                             Value inputRank) {
  assert(dim.getType().isa<IntegerType>() &&
         "dim arg of assertIsValidDim must be integer type");
  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  b.create<AssertOp>(loc, predGEZero,
                     b.getStringAttr("dim must be greater or equal to zero"));
  Value predLTInputRank =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
  b.create<AssertOp>(loc, predLTInputRank,
                     b.getStringAttr("dim must be smaller than inputRank"));
}

// Hack to deal with the Torch list type arguments which is not supported end
// to end. Constant values can be extracted directly and non constant
// list values are not supported.

@@ -459,6 +489,41 @@ static Value createLinalgPayloadCalculationForNormOps(
  return plusBias;
}

static void createLinalgPayloadCalculationForGatherOps(
    OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index,
    int64_t dim, int64_t outputRank) {
  SmallVector<Value> indices;
  for (int i = 0; i < inputRank; i++) {
    if (i == dim) {
      indices.push_back(castIntToIndex(b, loc, index));
    } else {
      // `outputRank` might be larger than `inputRank`. The `linalg::IndexOp`
      // takes in the dimension of the output. Add `inputDimOffset` to map to
      // the correct output dimension for dimensions larger than the given
      // `dim`.
      int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank;
      indices.push_back(b.create<linalg::IndexOp>(loc, i + inputDimOffset));
    }
  }

  // Assert index < input.sizes[dim]
  Value indexLTInputDim = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, index,
      castIndexToInt(b, loc, getDimOp(b, loc, input, dim)));
  b.create<AssertOp>(loc, indexLTInputDim,
                     b.getStringAttr("index must be smaller than dim size"));

  // Assert index >= 0
  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
  Value indexGEThanZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
  b.create<AssertOp>(loc, indexGEThanZero,
                     b.getStringAttr("index must be larger or equal to 0"));

  Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
  b.create<linalg::YieldOp>(loc, extract);
}

namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:

@@ -1027,6 +1092,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    ArrayRef<Value> operands) {
  if (isa<AtenTanhOp>(op))
    return b.create<math::TanhOp>(loc, payloadArgs[0]);
  if (isa<AtenExpOp>(op))
    return b.create<math::ExpOp>(loc, payloadArgs[0]);
  if (isa<AtenSigmoidOp>(op)) {
    Type elementType = payloadArgs[0].getType();
    auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));

@@ -1330,8 +1397,8 @@ struct ConvertElementwiseOp : ConversionPattern {
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp>(
            op))
             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
             AtenExpOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -1846,14 +1913,13 @@ public:
    for (auto i = 0; i < inputRank; i++)
      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
    for (auto i = 0; i < inputRank; i++) {
      if (i == dim0) {
      if (i == dim0)
        swapExprs.push_back(idExprs[dim1]);
      } else if (i == dim1) {
      else if (i == dim1)
        swapExprs.push_back(idExprs[dim0]);
      } else {
      else
        swapExprs.push_back(idExprs[i]);
      }
    }

    SmallVector<AffineMap> indexingMaps = {
        AffineMap::get(inputRank, 0, idExprs, op.getContext()),

@@ -1893,19 +1959,15 @@ public:

    // Collect all the tensors to be concatenated.
    auto tensorList = op.tensors();
    auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
    if (!listConstruct)
    SmallVector<Value> tensorsTorchType;
    if (!getListConstructElements(tensorList, tensorsTorchType))
      return op.emitError(
          "unimplemented: the tensor list is not from list construct");
    auto tensors = llvm::to_vector<4>(
        llvm::map_range(listConstruct.elements(), [&](Value tensor) -> Value {
          return typeConverter->materializeTargetConversion(
              rewriter, loc, getTypeConverter()->convertType(tensor.getType()),
              tensor);
        }));
    auto tensors =
        getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);

    RankedTensorType newResultType =
        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
    int rank = newResultType.getRank();
    SmallVector<Value> offsets, sizes, strides;
    sizes.reserve(rank);

@@ -1975,18 +2037,9 @@ public:
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, newResultTy, indices, result, affineMaps, iteratorTypes,
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          Value indexOfDim = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          SmallVector<Value> indices;
          for (int i = 0; i < rank; i++) {
            indices.push_back(i == dim
                                  ? indexOfDim
                                  : rewriter.create<linalg::IndexOp>(loc, i));
          }
          Value extract =
              rewriter.create<tensor::ExtractOp>(loc, self, indices);
          rewriter.create<linalg::YieldOp>(loc, extract);
          auto index = args[0];
          createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index,
                                                     dim, rank);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();

@@ -1994,6 +2047,93 @@ public:
};
} // namespace

namespace {
class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenEmbeddingOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    AtenEmbeddingOp::Adaptor adaptor(operands);
    Value weight = adaptor.weight();
    Value indices = adaptor.indices();
    RankedTensorType newResultType =
        typeConverter->convertType(op.getType()).cast<RankedTensorType>();

    auto weightTy = weight.getType().cast<RankedTensorType>();
    if (weightTy.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "weight must be rank 2");
    Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
    Type elemTy = weightTy.getElementType();

    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
    sizes.push_back(embeddingDim);
    int64_t resultRank = sizes.size();

    auto indicesTy = weight.getType().cast<RankedTensorType>();
    int64_t indicesRank = indicesTy.getRank();
    SmallVector<AffineExpr> indicesExprs;
    for (int i = 0; i < indicesRank; i++)
      indicesExprs.push_back(rewriter.getAffineDimExpr(i));
    auto indicesAffineMap = AffineMap::get(
        /*dimCount=*/resultRank,
        /*symbolCount=*/0, indicesExprs, op->getContext());
    SmallVector<AffineMap, 2> indexingMaps = {
        indicesAffineMap,
        rewriter.getMultiDimIdentityMap(resultRank),
    };
    SmallVector<StringRef> iteratorTypes(sizes.size(),
                                         getParallelIteratorTypeName());
    Value initTensor =
        rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
    Value embeddingResult =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), indices, initTensor,
                /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value index = args[0];
                  createLinalgPayloadCalculationForGatherOps(
                      b, loc, weight, weightTy.getRank(), index, /*dim=*/0,
                      resultRank);
                })
            .getResult(0);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
                                                embeddingResult);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenSizeIntOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    AtenSizeIntOp::Adaptor adaptor(operands);
    Value self = adaptor.self();
    Value dim = adaptor.dim();
    auto type = self.getType().cast<RankedTensorType>();
    Value inputRank = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(type.getRank()));
    Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
    assertIsValidDim(rewriter, loc, dimPositive, inputRank);
    Value size = rewriter.create<tensor::DimOp>(
        loc, adaptor.self(), castIntToIndex(rewriter, loc, dimPositive));
    rewriter.replaceOp(op, castIndexToInt(rewriter, loc, size));
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

@@ -2057,6 +2197,10 @@ public:
    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
    target.addIllegalOp<AtenArgmaxOp>();
    patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
    target.addIllegalOp<AtenSizeIntOp>();
    patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
    target.addIllegalOp<AtenEmbeddingOp>();
    patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
@@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringMap.h"

using namespace mlir;

@@ -505,6 +506,29 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

//===----------------------------------------------------------------------===//
// AtenSizeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
  auto type = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!type || !type.hasSizes())
    return nullptr;

  int64_t inputRank = type.getSizes().size();
  int64_t dim;
  if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
    return nullptr;
  dim = toPositiveDim(dim, inputRank);
  if (!isValidDim(dim, inputRank))
    return nullptr;

  if (type.getSizes()[dim] == kUnknownSize)
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 64),
                          type.getSizes()[dim]);
}

//===----------------------------------------------------------------------===//
// AtenGtIntOp
//===----------------------------------------------------------------------===//

@@ -655,6 +679,19 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
                              outputs[0].cast<BaseTensorType>());
}

void TensorStaticInfoCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    auto reverseCast =
        op.operand().getDefiningOp<Torch::TensorStaticInfoCastOp>();
    if (!reverseCast || reverseCast.operand().getType() != op.getType())
      return failure();

    rewriter.replaceOp(op, reverseCast.operand());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//
@@ -1,5 +1,6 @@
add_mlir_library(TorchMLIRTorchPasses
  AdjustCallingConventions.cpp
  DecomposeComplexOps.cpp
  Passes.cpp
  GlobalizeObjectGraph.cpp
  InlineGlobalSlots.cpp

@@ -23,6 +24,7 @@ add_mlir_library(TorchMLIRTorchPasses
  MLIRPass
  MLIRTransforms
  TorchMLIRTorchDialect
  TorchMLIRTorchUtils
)

torch_mlir_target_includes(TorchMLIRTorchPasses)
@@ -0,0 +1,100 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Value dim = op.dim();
    if (!op.dtype().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for softmax");

    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, "Only support floating type");
    // exp(x)
    Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);

    // sum(exp(x))
    Value dimList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(dim.getType()), dim);
    Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
    Value dtype = rewriter.create<ConstantNoneOp>(loc);
    SmallVector<int64_t> sizes;
    int64_t dimInt;
    if (tensorType.hasSizes()) {
      ArrayRef<int64_t> inputShape = tensorType.getSizes();
      int64_t inputRank = inputShape.size();
      if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
        dimInt = toPositiveDim(dimInt, inputRank);
        if (!isValidDim(dimInt, inputRank))
          return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
        sizes.append(inputShape.begin(), inputShape.end());
        sizes[dimInt] = 1;
      } else {
        sizes.resize(inputRank, kUnknownSize);
      }
    }
    Type resultType = tensorType.getWithSizesAndDtype(
        sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
                          : llvm::makeArrayRef(sizes),
        tensorType.getDtype());
    Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
                                                     dimList, keepDim, dtype);
    // exp(x) / sum(exp(x))
    Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
                                                        result);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    target.addLegalDialect<Torch::TorchDialect>();

    patterns.add<DecomposeAtenSoftmaxIntOp>(context);
    target.addIllegalOp<AtenSoftmaxIntOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
  return std::make_unique<DecomposeComplexOpsPass>();
}
@@ -90,7 +90,7 @@ public:
      if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
        copyToValueTensorOps.push_back(copyToValueTensor);
      } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                     AtenTransposeIntOp>(op)) {
                     AtenTransposeIntOp, TensorStaticInfoCastOp>(op)) {
        viewLikeOps.push_back(op);
        llvm::append_range(workList, op->getResult(0).getUsers());
      } else {
@@ -119,9 +119,17 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(

  // Do shape and dtype refinement.
  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());

  // Propagate to ABI return types the shape/dtype information discovered by
  // the previous pass. Doing this is ABI-compatible for our backends.
  pm.addPass(Torch::createRefinePublicReturnPass());

  if (options.optimize) {
    // This can fold away some branches given the information got from
    // RefineTypes before doing maximize value semantics which only works with
    // basic blocks.
    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  }
  // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
  pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());

@@ -134,6 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
    // only-used-in-training operations on `torch.global_slot`'s.
    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
  }
  pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());

  // TODO: VerifyTorchBackendContractPass.
}
@@ -19,6 +19,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;

@@ -263,7 +264,8 @@ public:
    } else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
      return visitAtenArangeStartOp(arangeStart);
    } else if (auto sum = dyn_cast<AtenSumOp>(op)) {
      return visitReductionAlongAllDimsOp(sum, operands);
      Type dtype = operands[0]->getValue().dtype;
      return visitReductionAlongAllDimsOp(sum, dtype, operands);
    } else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
      return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
                                             sumDimIntList.keepdim(), operands);

@@ -271,9 +273,17 @@ public:
      return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
                                             meanDim.keepdim(), operands);
    } else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
      return visitAtenArgmaxOp(argmax, operands);
      Value dim = argmax.dim();
      Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
      if (dim.getType().isa<Torch::NoneType>())
        return visitReductionAlongAllDimsOp(op, dtype, operands);
      if (dim.getType().isa<Torch::IntType>())
        return visitReductionAlongDimIntOp(argmax, argmax.dim(),
                                           argmax.keepdim(), dtype, operands);
    } else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
      return visitAtenAnyDimOp(anyDim, operands);
      Type dtype = operands[0]->getValue().dtype;
      return visitReductionAlongDimIntOp(anyDim, anyDim.dim(), anyDim.keepdim(),
                                         dtype, operands);
    } else if (auto view = dyn_cast<AtenViewOp>(op)) {
      return visitReshapeLikeOp(view, operands);
    } else if (auto resize = dyn_cast<AtenResize_Op>(op)) {

@@ -353,6 +363,8 @@ public:
      return visitAtenEmbeddingOp(embedding, operands);
    } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
      return visitAtenBmmOp(bmm, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as

@@ -394,15 +406,13 @@ private:
  ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op);
  ChangeResult visitAtenArangeOp(AtenArangeOp op);
  ChangeResult visitReductionAlongAllDimsOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
      Operation *op, Type dtype,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitReductionAlongDimIntListOp(
      Operation *op, Value dim, Value keepdim,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenArgmaxOp(AtenArgmaxOp op,
                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenAnyDimOp(AtenAnyDimOp op,
  ChangeResult visitReductionAlongDimIntOp(
      Operation *op, Value dim, Value keepdim, Type dtype,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  template <typename OpTy>
  ChangeResult

@@ -448,27 +458,34 @@ private:
  ChangeResult
  visitAtenBmmOp(AtenBmmOp op,
                 ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

static int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 ? dim : dim + inputRank;
}

static bool isValidDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 && dim < inputRank;
// Get the MLIR type of the tensor dtype given the dtype integer value and the
// input dtype. When DType is None the type is inferred from the input dtype.
static void fillInDTypeGivenDTypeIntAndInputDType(MLIRContext *context,
                                                  ValueKnowledge &knowledge,
                                                  Value dtype,
                                                  Type inputDType) {
  int64_t dtypeInt;
  if (dtype.getType().isa<Torch::NoneType>())
    knowledge.dtype = inputDType;
  else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
    knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt);
}

// Get the MLIR type of the tensor dtype given the dtype integer value and data
// type. When DType is None the type is inferred from the data type.
// type of torch type. When DType is None the type is inferred from the data
// type.
static void fillInDTypeGivenDTypeAndDataType(MLIRContext *context,
                                             ValueKnowledge &knowledge,
                                             Value dtype, Type dataType) {
  int64_t dtypeInt;
  if (dtype.getType().isa<Torch::NoneType>())
    knowledge.dtype = getDTypeFromTorchType(context, dataType);
  else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
    knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt);
  Type dtypeFromDataType = getDTypeFromTorchType(context, dataType);
  fillInDTypeGivenDTypeIntAndInputDType(context, knowledge, dtype,
                                        dtypeFromDataType);
}

static void fillInSizesGivenSizesList(ValueKnowledge &knowledge, Value sizes) {

@@ -718,10 +735,10 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
}

ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
    Operation *op, Type dtype,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
  knowledge.dtype = input.dtype;
  knowledge.dtype = dtype;
  // Reduction along all dims always results in a tensor of rank zero,
  // which is represented by the default empty `knowledge.sizes` vector
  knowledge.hasSizes = true;

@@ -764,53 +781,28 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
  }
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
    AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
  knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
  int64_t dim;
  bool keepDim;
  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
    int64_t inputRank = input.sizes.size();
    knowledge.hasSizes = true;
    if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
      knowledge.sizes = input.sizes;
      dim = toPositiveDim(dim, inputRank);
      if (isValidDim(dim, inputRank)) {
        if (keepDim)
          knowledge.sizes[dim] = 1;
        else
          knowledge.sizes.erase(knowledge.sizes.begin() + dim);
      }
    } else if (op.dim().getType().isa<IntegerType>())
      knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
                             kUnknownSize);
  }
  // If dim is no kind of Integer, keepDim is ignored,
  // and the result will be a rank-0 tensor
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
    AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
    Operation *op, Value dim, Value keepdim, Type dtype,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
  auto input = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.dtype = input.dtype;
  int64_t dim;
  knowledge.dtype = dtype;
  int64_t dimInt;
  bool keepDim;
  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
  if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
    int64_t inputRank = input.sizes.size();
    knowledge.hasSizes = true;
    if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
      knowledge.sizes = input.sizes;
      dim = toPositiveDim(dim, inputRank);
      if (isValidDim(dim, inputRank)) {
      dimInt = toPositiveDim(dimInt, inputRank);
      if (isValidDim(dimInt, inputRank)) {
        if (keepDim)
          knowledge.sizes[dim] = 1;
          knowledge.sizes[dimInt] = 1;
        else
          knowledge.sizes.erase(knowledge.sizes.begin() + dim);
          knowledge.sizes.erase(knowledge.sizes.begin() + dimInt);
      }
    } else {
      knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);

@@ -1081,6 +1073,19 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenSoftmaxIntOp(
    AtenSoftmaxIntOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto dtype = op.dtype();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = input.hasSizes;
  knowledge.sizes = input.sizes;
  fillInDTypeGivenDTypeIntAndInputDType(op->getContext(), knowledge, dtype,
                                        input.dtype);
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenBmmOp(
    AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
@@ -0,0 +1,6 @@
add_mlir_dialect_library(TorchMLIRTorchUtils
  Utils.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Utils
)
@@ -0,0 +1,26 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

namespace mlir {
namespace torch {
namespace Torch {

int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 ? dim : dim + inputRank;
}

bool isValidDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 && dim < inputRank;
}

} // namespace Torch
} // namespace torch
} // namespace mlir
@@ -45,6 +45,8 @@ static bool isArgMemRefTypeValid(Type type) {
    Type elemTy = memRefType.getElementType();
    if (elemTy.isa<Float32Type>()) {
      return true;
    } else if (elemTy.isa<Float64Type>()) {
      return true;
    } else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
      if (integerTy.isSignlessInteger(64))
        return true;

@@ -81,7 +83,8 @@ static LogicalResult mungeFunction(
  for (auto arg : func.getArguments()) {
    auto type = arg.getType();
    if (!isArgMemRefTypeValid(type))
      return emitError(arg.getLoc(), "argument must be a memref of f32 or i64");
      return emitError(arg.getLoc(),
                       "argument must be a memref of f32, f64, i64");
    auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
    arg.replaceAllUsesExcept(cast, cast);
    arg.setType(getAbiTypeForMemRef(type));

@@ -91,12 +94,17 @@ static LogicalResult mungeFunction(
  SmallVector<Operation *> toErase;
  bool hadError = false;
  func.walk([&](ReturnOp op) {
    auto returnType =
        op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
    auto memRefType = op.getOperandTypes()[0].dyn_cast<MemRefType>();
    if (!memRefType) {
      hadError = true;
      op.emitError("return value must be memref type");
      return;
    }
    auto returnType = memRefType.getElementType();
    auto it = consumeFuncReturnFuncs.find(returnType);
    if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
      hadError = true;
      op.emitError("must have one return value: a memref of f32 or i64");
      op.emitError("must have one return value: a memref of f32, i64 or f64");
      return;
    }

@@ -126,26 +134,27 @@ class MungeCallingConventions
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder b(module.getBodyRegion());
    auto consumeFuncReturnInt64Func = b.create<FuncOp>(
        module.getLoc(), "refbackend_consume_int64_func_return",
        FunctionType::get(
            module.getContext(),
            UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
        b.getStringAttr("private"));
    auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
        module.getLoc(), "refbackend_consume_float32_func_return",
        FunctionType::get(
            module.getContext(),
            UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
        b.getStringAttr("private"));
    addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
    addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
    consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
    consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
    auto createConsumeFuncReturnFunc = [&](Type elemTy, std::string funcName) {
      auto consumeFuncReturnFunc = b.create<FuncOp>(
          module.getLoc(), funcName,
          FunctionType::get(module.getContext(),
                            UnrankedMemRefType::get(elemTy, /*memorySpace=*/0),
                            {}),
          b.getStringAttr("private"));
      addEmitCInterfaceAttr(consumeFuncReturnFunc);
      consumeFuncReturnFuncs[elemTy] = consumeFuncReturnFunc;
      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
    };
    createConsumeFuncReturnFunc(b.getI64Type(),
                                "refbackend_consume_int64_func_return");
    createConsumeFuncReturnFunc(b.getF32Type(),
                                "refbackend_consume_float32_func_return");
    createConsumeFuncReturnFunc(b.getF64Type(),
                                "refbackend_consume_float64_func_return");
    for (auto func : module.getOps<FuncOp>()) {
      if (func == consumeFuncReturnInt64Func ||
          func == consumeFuncReturnFloat32Func)
      if (consumeFuncReturnFuncsSet.contains(func))
        continue;
      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
        return signalPassFailure();
@@ -482,6 +482,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit(
        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )
    emit(
        "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
    )
    emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
    emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
    emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")

@@ -525,7 +528,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
    emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
    emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
    emit("aten::size.int : (Tensor, int) -> (int)")
    emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
    emit("aten::stack : (Tensor[], int) -> (Tensor)")
    emit("aten::sum : (Tensor, int?) -> (Tensor)")
    emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
@@ -24,13 +24,8 @@ __all__ = [


def checkArgTypeIsSupported(ty):
    if ty == np.float32:
        return
    elif ty == np.int64:
        return
    assert False, "Only tensor argument of float32 and int64 are supported but got " + str(
        ty)

    SUPPORTED = [np.float32, np.float64, np.int64]
    assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"


class RefBackendInvoker:
    def __init__(self, module):

@@ -45,12 +40,19 @@ class RefBackendInvoker:
        def consume_f32_return(a):
            self.result = unranked_memref_to_numpy(a, np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_f64_return(a):
            self.result = unranked_memref_to_numpy(a, np.float64)

        self.ee.register_runtime("refbackend_consume_int64_func_return",
                                 consume_i64_return)

        self.ee.register_runtime("refbackend_consume_float32_func_return",
                                 consume_f32_return)

        self.ee.register_runtime("refbackend_consume_float64_func_return",
                                 consume_f64_return)

    def __getattr__(self, function_name: str):
        def invoke(*args):
            ffi_args = []
@@ -489,3 +489,52 @@ func @torch.prim.dtype$int64(%t : !torch.tensor<*,si64>) -> !torch.int {
  %ret = torch.prim.dtype %t: !torch.tensor<*,si64> -> !torch.int
  return %ret : !torch.int
}

// CHECK-LABEL:   func @torch.aten.size.int$neg_dim(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK:           %[[RET:.*]] = torch.constant.int 2
// CHECK:           return %[[RET]] : !torch.int
func @torch.aten.size.int$neg_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
  %int-2 = torch.constant.int -2
  %ret = torch.aten.size.int %t, %int-2 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
  return %ret : !torch.int
}

// CHECK-LABEL:   func @torch.aten.size.int$pos_dim(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK:           %[[RET:.*]] = torch.constant.int 3
// CHECK:           return %[[RET]] : !torch.int
func @torch.aten.size.int$pos_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
  %int1 = torch.constant.int 1
  %ret = torch.aten.size.int %t, %int1 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
  return %ret : !torch.int
}

// CHECK-LABEL:   func @torch.aten.size.int$invalid_dim(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK:           %[[CST3:.*]] = torch.constant.int 3
// CHECK:           %[[RET:.*]] = torch.aten.size.int %[[T]], %[[CST3]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
// CHECK:           return %[[RET]] : !torch.int
func @torch.aten.size.int$invalid_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
  %int3 = torch.constant.int 3
  %ret = torch.aten.size.int %t, %int3 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
  return %ret : !torch.int
}

// CHECK-LABEL:   func @torch.tensor_static_info_cast$downcast_first(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK:           return %[[T]] : !torch.tensor
func @torch.tensor_static_info_cast$downcast_first(%t: !torch.tensor) -> !torch.tensor {
  %downcast = torch.tensor_static_info_cast %t : !torch.tensor to !torch.tensor<[?,?],f64>
  %upcast = torch.tensor_static_info_cast %downcast : !torch.tensor<[?,?],f64> to !torch.tensor
  return %upcast: !torch.tensor
}

// CHECK-LABEL:   func @torch.tensor_static_info_cast$upcast_first(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> {
// CHECK:           return %[[T]] : !torch.tensor<[?,?],f64>
func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> {
  %upcast = torch.tensor_static_info_cast %t : !torch.tensor<[?,?],f64> to !torch.tensor
  %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
  return %downcast: !torch.tensor<[?,?],f64>
}
@@ -925,3 +925,33 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
  %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
  return %ret: !torch.tensor
}

// ----
// CHECK-LABEL:   func @torch.aten.softmax.int(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME:        %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK:           %[[DTYPE:.*]] = torch.constant.none
// CHECK:           %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
// CHECK:           %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor
// CHECK:           return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
  %none = torch.constant.none
  %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
  return %ret : !torch.tensor
}

// ----
// CHECK-LABEL:   func @torch.aten.softmax.int$specified_dtype(
// CHECK-SAME:        %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME:        %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK:           %[[DTYPE:.*]] = torch.constant.int 4
// CHECK:           %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],si64>
// CHECK:           %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],si64> to !torch.tensor
// CHECK:           return %[[RET]] : !torch.tensor
// CHECK:         }
func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
  %int4 = torch.constant.int 4
  %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
  return %ret : !torch.tensor
}