E2e support for aten.softmax.int and aten.embedding

- Added a DecomposeComplexOps pass to decompose complex torchOps.
- Refactored `visitAtenArgmaxOp` and `visitAtenAnyDimOp` to
`visitReductionAlongDimIntOp`.
- Moved some helper functions into
torch-mlir/Dialect/Torch/Utils/Utils.h to be shared by multiple files.
- Added support for f64 tensor as argument and return types.
pull/353/head snapshot-20211018.30
Yi Zhang 2021-10-15 18:23:59 -04:00
parent 0902438882
commit a459e09ab7
21 changed files with 726 additions and 120 deletions

View File

@ -259,3 +259,125 @@ class GatherModule(torch.nn.Module):
@register_test_case(module_factory=lambda: GatherModule())
def GatherModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
class AddSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, tensor):
# This is a workaround for not supporting scalar arguments.
# TODO: pass in dim as an argument to the forward method when scalar
# arguments are supported.
return tensor.add(tensor, alpha=tensor.size(1))
@register_test_case(module_factory=lambda: AddSizeIntModule())
def AddSizeIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))
class AddSizeIntNegDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, tensor):
# This is a workaround for not supporting scalar arguments.
# TODO: pass in dim as an argument to the forward method when scalar
# arguments are supported.
return tensor.add(tensor, alpha=tensor.size(-2))
@register_test_case(module_factory=lambda: AddSizeIntNegDimModule())
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))
class EmbeddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.embed = torch.nn.Embedding(num_embeddings=100,
embedding_dim=50,
padding_idx=4)
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, indices):
return self.embed.forward(indices)
@register_test_case(module_factory=lambda: EmbeddingModule())
def EmbeddingModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 3)))
class SoftmaxIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.softmax = torch.nn.Softmax(2)
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@register_test_case(module_factory=lambda: SoftmaxIntModule())
def SoftmaxIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
class SoftmaxIntNegDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.softmax = torch.nn.Softmax(-2)
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@register_test_case(module_factory=lambda: SoftmaxIntNegDimModule())
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.softmax = torch.nn.Softmax(2)
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module())
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).double())

View File

@ -937,6 +937,22 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
let assemblyFormat = "$self `,` $kernel_size `,` $stride `,` $padding `,` $dilation `,` $ceil_mode attr-dict `:` type($self) `,` type($kernel_size) `,` type($stride) `,` type($padding) `,` type($dilation) `,` type($ceil_mode) `->` type($result)";
}
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
TorchOptionalIntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
}
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics
@ -1591,6 +1607,7 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
Torch_IntType:$result
);
let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
let hasFolder = 1;
}
def Torch_AtenStackOp : Torch_Op<"aten.stack", [

View File

@ -819,6 +819,7 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasCanonicalizer = 1;
}
def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [

View File

@ -54,6 +54,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass();
} // namespace Torch
/// Registers all Torch transformation passes.

View File

@ -215,4 +215,20 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}];
}
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
let summary = "Decompose complicated torch operations";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let description = [{
Decompose torch operation that are losslessly represented as combinations of
other operations, modulo appropropriate compiler fusion. Note that this pass
is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
systematic reductions of a large number of ops at once, guided mostly by
traits.
An example of the transformations done in this pass is:
- convert aten.softmax to softmax(x, dim)
=> tmp=exp(x); tmp / sum(tmp, dim, keepdim=True)
}];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_UTILS_H
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace torch {
namespace Torch {
int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
} // namespace Torch
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_DIALECT_TORCH_UTILS_H

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@ -70,6 +71,35 @@ static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
return success();
}
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
Value inputRank) {
assert(dim.getType().isa<IntegerType>() &&
"dim arg of toPositiveDim must be integer type");
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
Value predDimGEZero =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
return dimInt;
}
// Generate IR: assert(dim >= 0 && dim < inputRank)
static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
Value inputRank) {
assert(dim.getType().isa<IntegerType>() &&
"dim arg of assertIsValidDim must be integer type");
Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
Value predGEZero =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
b.create<AssertOp>(loc, predGEZero,
b.getStringAttr("dim must be greater or equal to zero"));
Value predLTInputRank =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
b.create<AssertOp>(loc, predLTInputRank,
b.getStringAttr("dim must be smaller than inputRank"));
}
// Hack to deal with the Torch list type arguments which is not supported end
// to end. Constant values can be be extracted directly and non constant
// list values are not supported.
@ -459,6 +489,41 @@ static Value createLinalgPayloadCalculationForNormOps(
return plusBias;
}
static void createLinalgPayloadCalculationForGatherOps(
OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index,
int64_t dim, int64_t outputRank) {
SmallVector<Value> indices;
for (int i = 0; i < inputRank; i++) {
if (i == dim) {
indices.push_back(castIntToIndex(b, loc, index));
} else {
// `outputRank` might be larger than `inputRank`. The `linalg::IndexOp`
// takes in the dimension of the output. Add `inputDimOffset` to
// related to the correct dimension of the output for dimension larger
// than the given `dim`.
int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank;
indices.push_back(b.create<linalg::IndexOp>(loc, i + inputDimOffset));
}
}
// Assert index < input.sizes[dim]
Value indexLTInputDim = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, index,
castIndexToInt(b, loc, getDimOp(b, loc, input, dim)));
b.create<AssertOp>(loc, indexLTInputDim,
b.getStringAttr("index must be smaller than dim size"));
// Assert index >= 0
Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
Value indexGEThanZero =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
b.create<AssertOp>(loc, indexGEThanZero,
b.getStringAttr("index must be larger or equal to 0"));
Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
b.create<linalg::YieldOp>(loc, extract);
}
namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
@ -1027,6 +1092,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
ArrayRef<Value> operands) {
if (isa<AtenTanhOp>(op))
return b.create<math::TanhOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op))
return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenSigmoidOp>(op)) {
Type elementType = payloadArgs[0].getType();
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
@ -1330,8 +1397,8 @@ struct ConvertElementwiseOp : ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp>(
op))
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenExpOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1846,14 +1913,13 @@ public:
for (auto i = 0; i < inputRank; i++)
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
for (auto i = 0; i < inputRank; i++) {
if (i == dim0) {
if (i == dim0)
swapExprs.push_back(idExprs[dim1]);
} else if (i == dim1) {
else if (i == dim1)
swapExprs.push_back(idExprs[dim0]);
} else {
else
swapExprs.push_back(idExprs[i]);
}
}
SmallVector<AffineMap> indexingMaps = {
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
@ -1893,19 +1959,15 @@ public:
// Collect all the tensors to be concatenated.
auto tensorList = op.tensors();
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct)
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto tensors = llvm::to_vector<4>(
llvm::map_range(listConstruct.elements(), [&](Value tensor) -> Value {
return typeConverter->materializeTargetConversion(
rewriter, loc, getTypeConverter()->convertType(tensor.getType()),
tensor);
}));
auto tensors =
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
RankedTensorType newResultType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
int rank = newResultType.getRank();
SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank);
@ -1975,18 +2037,9 @@ public:
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, newResultTy, indices, result, affineMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
auto indexValue = args[0];
Value indexOfDim = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), indexValue);
SmallVector<Value> indices;
for (int i = 0; i < rank; i++) {
indices.push_back(i == dim
? indexOfDim
: rewriter.create<linalg::IndexOp>(loc, i));
}
Value extract =
rewriter.create<tensor::ExtractOp>(loc, self, indices);
rewriter.create<linalg::YieldOp>(loc, extract);
auto index = args[0];
createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index,
dim, rank);
});
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
@ -1994,6 +2047,93 @@ public:
};
} // namespace
namespace {
class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenEmbeddingOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
AtenEmbeddingOp::Adaptor adaptor(operands);
Value weight = adaptor.weight();
Value indices = adaptor.indices();
RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
auto weightTy = weight.getType().cast<RankedTensorType>();
if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
Type elemTy = weightTy.getElementType();
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
sizes.push_back(embeddingDim);
int64_t resultRank = sizes.size();
auto indicesTy = weight.getType().cast<RankedTensorType>();
int64_t indicesRank = indicesTy.getRank();
SmallVector<AffineExpr> indicesExprs;
for (int i = 0; i < indicesRank; i++)
indicesExprs.push_back(rewriter.getAffineDimExpr(i));
auto indicesAffineMap = AffineMap::get(
/*dimCount=*/resultRank,
/*symbolCount=*/0, indicesExprs, op->getContext());
SmallVector<AffineMap, 2> indexingMaps = {
indicesAffineMap,
rewriter.getMultiDimIdentityMap(resultRank),
};
SmallVector<StringRef> iteratorTypes(sizes.size(),
getParallelIteratorTypeName());
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
Value embeddingResult =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), indices, initTensor,
/*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = args[0];
createLinalgPayloadCalculationForGatherOps(
b, loc, weight, weightTy.getRank(), index, /*dim=*/0,
resultRank);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
embeddingResult);
return success();
}
};
} // namespace
namespace {
class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSizeIntOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
AtenSizeIntOp::Adaptor adaptor(operands);
Value self = adaptor.self();
Value dim = adaptor.dim();
auto type = self.getType().cast<RankedTensorType>();
Value inputRank = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(type.getRank()));
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
assertIsValidDim(rewriter, loc, dimPositive, inputRank);
Value size = rewriter.create<tensor::DimOp>(
loc, adaptor.self(), castIntToIndex(rewriter, loc, dimPositive));
rewriter.replaceOp(op, castIndexToInt(rewriter, loc, size));
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -2057,6 +2197,10 @@ public:
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
target.addIllegalOp<AtenSizeIntOp>();
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingOp>();
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringMap.h"
using namespace mlir;
@ -505,6 +506,29 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenSizeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
auto type = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!type || !type.hasSizes())
return nullptr;
int64_t inputRank = type.getSizes().size();
int64_t dim;
if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
return nullptr;
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return nullptr;
if (type.getSizes()[dim] == kUnknownSize)
return nullptr;
return IntegerAttr::get(IntegerType::get(getContext(), 64),
type.getSizes()[dim]);
}
//===----------------------------------------------------------------------===//
// AtenGtIntOp
//===----------------------------------------------------------------------===//
@ -655,6 +679,19 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
outputs[0].cast<BaseTensorType>());
}
void TensorStaticInfoCastOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
auto reverseCast =
op.operand().getDefiningOp<Torch::TensorStaticInfoCastOp>();
if (!reverseCast || reverseCast.operand().getType() != op.getType())
return failure();
rewriter.replaceOp(op, reverseCast.operand());
return success();
});
}
//===----------------------------------------------------------------------===//
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//

View File

@ -1,5 +1,6 @@
add_mlir_library(TorchMLIRTorchPasses
AdjustCallingConventions.cpp
DecomposeComplexOps.cpp
Passes.cpp
GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp
@ -23,6 +24,7 @@ add_mlir_library(TorchMLIRTorchPasses
MLIRPass
MLIRTransforms
TorchMLIRTorchDialect
TorchMLIRTorchUtils
)
torch_mlir_target_includes(TorchMLIRTorchPasses)

View File

@ -0,0 +1,100 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value dim = op.dim();
if (!op.dtype().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for softmax");
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
// exp(x)
Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
// sum(exp(x))
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(dim.getType()), dim);
Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
Value dtype = rewriter.create<ConstantNoneOp>(loc);
SmallVector<int64_t> sizes;
int64_t dimInt;
if (tensorType.hasSizes()) {
ArrayRef<int64_t> inputShape = tensorType.getSizes();
int64_t inputRank = inputShape.size();
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
sizes.append(inputShape.begin(), inputShape.end());
sizes[dimInt] = 1;
} else {
sizes.resize(inputRank, kUnknownSize);
}
}
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
tensorType.getDtype());
Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
dimList, keepDim, dtype);
// exp(x) / sum(exp(x))
Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
result);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect>();
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
target.addIllegalOp<AtenSoftmaxIntOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
return std::make_unique<DecomposeComplexOpsPass>();
}

View File

@ -90,7 +90,7 @@ public:
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
copyToValueTensorOps.push_back(copyToValueTensor);
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
AtenTransposeIntOp>(op)) {
AtenTransposeIntOp, TensorStaticInfoCastOp>(op)) {
viewLikeOps.push_back(op);
llvm::append_range(workList, op->getResult(0).getUsers());
} else {

View File

@ -119,9 +119,17 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Do shape and dtype refinement.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());
if (options.optimize) {
// This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with
// basic blocks.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
}
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
@ -134,6 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// only-used-in-training operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
}
pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());
// TODO: VerifyTorchBackendContractPass.
}

View File

@ -19,6 +19,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
@ -263,7 +264,8 @@ public:
} else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
return visitAtenArangeStartOp(arangeStart);
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
return visitReductionAlongAllDimsOp(sum, operands);
Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongAllDimsOp(sum, dtype, operands);
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
sumDimIntList.keepdim(), operands);
@ -271,9 +273,17 @@ public:
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
return visitAtenArgmaxOp(argmax, operands);
Value dim = argmax.dim();
Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
if (dim.getType().isa<Torch::NoneType>())
return visitReductionAlongAllDimsOp(op, dtype, operands);
if (dim.getType().isa<Torch::IntType>())
return visitReductionAlongDimIntOp(argmax, argmax.dim(),
argmax.keepdim(), dtype, operands);
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
return visitAtenAnyDimOp(anyDim, operands);
Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongDimIntOp(anyDim, anyDim.dim(), anyDim.keepdim(),
dtype, operands);
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
return visitReshapeLikeOp(view, operands);
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
@ -353,6 +363,8 @@ public:
return visitAtenEmbeddingOp(embedding, operands);
} else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
return visitAtenBmmOp(bmm, operands);
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
}
// Otherwise, this is an unknown operation. Just mark all results as
@ -394,15 +406,13 @@ private:
ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op);
ChangeResult visitAtenArangeOp(AtenArangeOp op);
ChangeResult visitReductionAlongAllDimsOp(
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
Operation *op, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenArgmaxOp(AtenArgmaxOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenAnyDimOp(AtenAnyDimOp op,
ChangeResult visitReductionAlongDimIntOp(
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy>
ChangeResult
@ -448,27 +458,34 @@ private:
ChangeResult
visitAtenBmmOp(AtenBmmOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
static int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
return dim >= 0 ? dim : dim + inputRank;
}
static bool isValidDim(int64_t dim, int64_t inputRank) {
return dim >= 0 && dim < inputRank;
// Get the MLIR type of the tensor dtype given the dtype integer value and the
// input dtype. When DType is None the type is inferred from the input dtype.
static void fillInDTypeGivenDTypeIntAndInputDType(MLIRContext *context,
ValueKnowledge &knowledge,
Value dtype,
Type inputDType) {
int64_t dtypeInt;
if (dtype.getType().isa<Torch::NoneType>())
knowledge.dtype = inputDType;
else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt);
}
// Get the MLIR type of the tensor dtype given the dtype integer value and data
// type. When DType is None the type is inferred from the data type.
// type of torch type. When DType is None the type is inferred from the data
// type.
static void fillInDTypeGivenDTypeAndDataType(MLIRContext *context,
ValueKnowledge &knowledge,
Value dtype, Type dataType) {
int64_t dtypeInt;
if (dtype.getType().isa<Torch::NoneType>())
knowledge.dtype = getDTypeFromTorchType(context, dataType);
else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt);
Type dtypeFromDataType = getDTypeFromTorchType(context, dataType);
fillInDTypeGivenDTypeIntAndInputDType(context, knowledge, dtype,
dtypeFromDataType);
}
static void fillInSizesGivenSizesList(ValueKnowledge &knowledge, Value sizes) {
@ -718,10 +735,10 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) {
}
ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
Operation *op, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
knowledge.dtype = dtype;
// Reduction along all dims always results in a tensor of rank zero,
// which is represented by the default empty `knowledge.sizes` vector
knowledge.hasSizes = true;
@ -764,53 +781,28 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
}
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
int64_t dim;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else if (op.dim().getType().isa<IntegerType>())
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
kUnknownSize);
}
// If dim is no kind of Integer, keepDim is ignored,
// and the result will bea rank-0 tensor
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
assert(dim.getType().isa<Torch::IntType>() && "dim must be int type");
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
int64_t dim;
knowledge.dtype = dtype;
int64_t dimInt;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
dimInt = toPositiveDim(dimInt, inputRank);
if (isValidDim(dimInt, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
knowledge.sizes[dimInt] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
knowledge.sizes.erase(knowledge.sizes.begin() + dimInt);
}
} else {
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
@ -1081,6 +1073,19 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
return getLatticeElement(op.getResult()).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenSoftmaxIntOp(
AtenSoftmaxIntOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto dtype = op.dtype();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = input.hasSizes;
knowledge.sizes = input.sizes;
fillInDTypeGivenDTypeIntAndInputDType(op->getContext(), knowledge, dtype,
input.dtype);
return getLatticeElement(op.getResult()).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenBmmOp(
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge =

View File

@ -0,0 +1,6 @@
add_mlir_dialect_library(TorchMLIRTorchUtils
Utils.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Utils
)

View File

@ -0,0 +1,26 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
namespace mlir {
namespace torch {
namespace Torch {
int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
return dim >= 0 ? dim : dim + inputRank;
}
bool isValidDim(int64_t dim, int64_t inputRank) {
return dim >= 0 && dim < inputRank;
}
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -45,6 +45,8 @@ static bool isArgMemRefTypeValid(Type type) {
Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float32Type>()) {
return true;
} else if (elemTy.isa<Float64Type>()) {
return true;
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
if (integerTy.isSignlessInteger(64))
return true;
@ -81,7 +83,8 @@ static LogicalResult mungeFunction(
for (auto arg : func.getArguments()) {
auto type = arg.getType();
if (!isArgMemRefTypeValid(type))
return emitError(arg.getLoc(), "argument must be a memref of f32 or i64");
return emitError(arg.getLoc(),
"argument must be a memref of f32, f64, i64");
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
arg.replaceAllUsesExcept(cast, cast);
arg.setType(getAbiTypeForMemRef(type));
@ -91,12 +94,17 @@ static LogicalResult mungeFunction(
SmallVector<Operation *> toErase;
bool hadError = false;
func.walk([&](ReturnOp op) {
auto returnType =
op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType();
auto memRefType = op.getOperandTypes()[0].dyn_cast<MemRefType>();
if (!memRefType) {
hadError = true;
op.emitError("return value must be memref type");
return;
}
auto returnType = memRefType.getElementType();
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
hadError = true;
op.emitError("must have one return value: a memref of f32 or i64");
op.emitError("must have one return value: a memref of f32, i64 or f64");
return;
}
@ -126,26 +134,27 @@ class MungeCallingConventions
void runOnOperation() override {
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
auto consumeFuncReturnInt64Func = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_int64_func_return",
FunctionType::get(
module.getContext(),
UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private"));
auto consumeFuncReturnFloat32Func = b.create<FuncOp>(
module.getLoc(), "refbackend_consume_float32_func_return",
FunctionType::get(
module.getContext(),
UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnInt64Func);
addEmitCInterfaceAttr(consumeFuncReturnFloat32Func);
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func;
consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func;
DenseSet<FuncOp> consumeFuncReturnFuncsSet;
auto createConsumeFuncReturnFunc = [&](Type elemTy, std::string funcName) {
auto consumeFuncReturnFunc = b.create<FuncOp>(
module.getLoc(), funcName,
FunctionType::get(module.getContext(),
UnrankedMemRefType::get(elemTy, /*memorySpace=*/0),
{}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
consumeFuncReturnFuncs[elemTy] = consumeFuncReturnFunc;
consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
};
createConsumeFuncReturnFunc(b.getI64Type(),
"refbackend_consume_int64_func_return");
createConsumeFuncReturnFunc(b.getF32Type(),
"refbackend_consume_float32_func_return");
createConsumeFuncReturnFunc(b.getF64Type(),
"refbackend_consume_float64_func_return");
for (auto func : module.getOps<FuncOp>()) {
if (func == consumeFuncReturnInt64Func ||
func == consumeFuncReturnFloat32Func)
if (consumeFuncReturnFuncsSet.contains(func))
continue;
if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
return signalPassFailure();

View File

@ -482,6 +482,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
)
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
@ -525,7 +528,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")

View File

@ -24,13 +24,8 @@ __all__ = [
def checkArgTypeIsSupported(ty):
if ty == np.float32:
return
elif ty == np.int64:
return
assert False, "Only tensor argument of float32 and int64 are supported but got " + str(
ty)
SUPPORTED = [np.float32, np.float64, np.int64]
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
class RefBackendInvoker:
def __init__(self, module):
@ -45,12 +40,19 @@ class RefBackendInvoker:
def consume_f32_return(a):
self.result = unranked_memref_to_numpy(a, np.float32)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_f64_return(a):
self.result = unranked_memref_to_numpy(a, np.float64)
self.ee.register_runtime("refbackend_consume_int64_func_return",
consume_i64_return)
self.ee.register_runtime("refbackend_consume_float32_func_return",
consume_f32_return)
self.ee.register_runtime("refbackend_consume_float64_func_return",
consume_f64_return)
def __getattr__(self, function_name: str):
def invoke(*args):
ffi_args = []

View File

@ -489,3 +489,52 @@ func @torch.prim.dtype$int64(%t : !torch.tensor<*,si64>) -> !torch.int {
%ret = torch.prim.dtype %t: !torch.tensor<*,si64> -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.aten.size.int$neg_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK: %[[RET:.*]] = torch.constant.int 2
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.size.int$neg_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
%int-2 = torch.constant.int -2
%ret = torch.aten.size.int %t, %int-2 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.aten.size.int$pos_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK: %[[RET:.*]] = torch.constant.int 3
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.size.int$pos_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
%int1 = torch.constant.int 1
%ret = torch.aten.size.int %t, %int1 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.aten.size.int$invalid_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: %[[RET:.*]] = torch.aten.size.int %[[T]], %[[CST3]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.size.int$invalid_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int {
%int3 = torch.constant.int 3
%ret = torch.aten.size.int %t, %int3 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func @torch.tensor_static_info_cast$downcast_first(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[T]] : !torch.tensor
func @torch.tensor_static_info_cast$downcast_first(%t: !torch.tensor) -> !torch.tensor {
%downcast = torch.tensor_static_info_cast %t : !torch.tensor to !torch.tensor<[?,?],f64>
%upcast = torch.tensor_static_info_cast %downcast : !torch.tensor<[?,?],f64> to !torch.tensor
return %upcast: !torch.tensor
}
// CHECK-LABEL: func @torch.tensor_static_info_cast$upcast_first(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> {
// CHECK: return %[[T]] : !torch.tensor<[?,?],f64>
func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> {
%upcast = torch.tensor_static_info_cast %t : !torch.tensor<[?,?],f64> to !torch.tensor
%downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
return %downcast: !torch.tensor<[?,?],f64>
}

View File

@ -925,3 +925,33 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indic
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}
// ----
// CHECK-LABEL: func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],si64> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
// CHECK: }
func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%int4 = torch.constant.int 4
%ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}