Add support for constant_pad_nd

Note that, to enable folding of the code produced by an example
like the ConstantPad2dStaticModule e2e test, support for several other
operations had to be added or improved:
- aten::neg.int
- aten::eq.float
- aten::eq.str
- prim::Uninitialized
pull/519/head snapshot-20220111.200
Liam Fitzpatrick 2022-01-11 07:42:53 +00:00 committed by Yi Zhang
parent 35cf8d18f7
commit 077e55d756
10 changed files with 391 additions and 42 deletions
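For reference, a minimal illustration (not part of the diff) of the pad-list convention the new op follows: aten::constant_pad_nd takes (low, high) padding pairs starting from the last dimension and moving forward, the same convention as torch.nn.functional.pad.

import torch

x = torch.zeros(1, 1, 20, 20)
# (0, 1) pads the last dimension by 0 (low) and 1 (high);
# (2, 3) pads the second-to-last dimension by 2 (low) and 3 (high).
y = torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), float("-inf"))
print(y.shape)  # torch.Size([1, 1, 25, 21])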

.gitignore

@ -1,4 +1,5 @@
*.swp
.cache/
.vscode
.env
*.code-workspace


@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -38,7 +37,6 @@ def MmModule_chained(module, tu: TestUtils):
# ==============================================================================
class BmmModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -57,10 +55,8 @@ class BmmModule(torch.nn.Module):
def BmmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
# ==============================================================================
# A subgraph with multiple mm ops.
class MmDagModule(torch.nn.Module):
def __init__(self):
@ -80,10 +76,8 @@ class MmDagModule(torch.nn.Module):
def MmDagModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
# ==============================================================================
class MmTanhModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -100,8 +94,6 @@ class MmTanhModule(torch.nn.Module):
def matmul(self, lhs, rhs):
return torch.mm(lhs, rhs)
# ==============================================================================
@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
@ -109,7 +101,6 @@ def MmTanhModule_basic(module, tu: TestUtils):
# ==============================================================================
class AddmmModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@ -196,7 +187,6 @@ def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
# ==============================================================================
class FlattenStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -217,7 +207,6 @@ def FlattenStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class FlattenRank0Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -238,7 +227,6 @@ def FlattenRank0Module_basic(module, tu: TestUtils):
# ==============================================================================
class FlattenDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -259,7 +247,6 @@ def FlattenDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================
class MaxPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -276,14 +263,86 @@ class MaxPool2dModule(torch.nn.Module):
def forward(self, x):
return self.mp2d(x)
# ==============================================================================
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
class ConstantPad2dStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float('inf'))

    @export
    @annotate_args([
        None,
        ([1, 1, 20, 20], torch.float32, True),
    ])
    def forward(self, x):
        return self.pad2d(x)


@register_test_case(module_factory=lambda: ConstantPad2dStaticModule())
def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)

# ==============================================================================

class ConstantPadNdModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))


@register_test_case(module_factory=lambda: ConstantPadNdModule())
def ConstantPadNdModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)


class ConstantPadNdStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 1, 20, 20, 4, 4], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))


@register_test_case(module_factory=lambda: ConstantPadNdStaticModule())
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)


class ConstantPadNdPartialStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 1, 20, 20, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))


@register_test_case(module_factory=lambda: ConstantPadNdPartialStaticModule())
def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
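As a sanity check (illustrative only, not part of the commit), the output shapes the three ConstantPadNd tests above should produce, given the last-dimension-first pad ordering:

import torch

x = torch.rand(1, 1, 20, 20, 4, 4)
# pad=(0, 1): only the last dimension grows, 4 -> 5.
print(torch.ops.aten.constant_pad_nd(x, (0, 1), 0.0).shape)        # [1, 1, 20, 20, 4, 5]
# pad=(0, 1, 2, 3): last dim 4 -> 5, second-to-last dim 4 -> 9.
print(torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), 0.0).shape)  # [1, 1, 20, 20, 9, 5]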
# ==============================================================================
class TransposeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -296,13 +355,13 @@ class TransposeIntModule(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, 0, 1)
# ==============================================================================
@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
# ==============================================================================
class PermuteModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -333,13 +392,12 @@ class TransposeIntNegDimsModule(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
# ==============================================================================
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
# ==============================================================================
class PermuteNegativeIndexModule(torch.nn.Module):
def __init__(self):
@ -353,11 +411,12 @@ class PermuteNegativeIndexModule(torch.nn.Module):
def forward(self, x):
return x.permute(0, -1, 1)
# ==============================================================================
@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
# ==============================================================================
class TensorsConcatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -379,7 +438,6 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
# ==============================================================================
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -422,7 +480,6 @@ def AddSizeIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class AddSizeIntNegDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -505,7 +562,6 @@ def _SoftmaxModule_basic(module, tu: TestUtils):
# ==============================================================================
class SoftmaxIntNegDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -527,7 +583,6 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
# ==============================================================================
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def __init__(self):
super().__init__()


@ -1778,6 +1778,22 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchIntListType:$pad,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $pad `,` $value attr-dict `:` type($self) `,` type($pad) `,` type($value) `->` type($result)";
}
def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
AllowsTypeRefinement
]> {
@ -2915,6 +2931,22 @@ def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
}
def Torch_AtenEqStrOp : Torch_Op<"aten.eq.str", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::eq.str : (str, str) -> (bool)`";
let arguments = (ins
Torch_StringType:$a,
Torch_StringType:$b
);
let results = (outs
Torch_BoolType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
let hasFolder = 1;
}
def Torch_AtenStrOp : Torch_Op<"aten.str", [
AllowsTypeRefinement,
HasValueSemantics
@ -3175,6 +3207,21 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
let hasFolder = 1;
}
def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::neg.int : (int) -> (int)`";
let arguments = (ins
Torch_IntType:$a
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let hasFolder = 1;
}
def Torch_AtenLogIntOp : Torch_Op<"aten.log.int", [
AllowsTypeRefinement,
HasValueSemantics
@ -3248,6 +3295,22 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
}
def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::eq.float : (float, float) -> (bool)`";
let arguments = (ins
Torch_FloatType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_BoolType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
let hasFolder = 1;
}
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
AllowsTypeRefinement,
HasValueSemantics


@ -185,6 +185,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
AnyTorchType:$result
);
let assemblyFormat = " attr-dict `:` type($result)";
let hasCanonicalizer = 1;
}
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [


@ -45,6 +45,21 @@ struct torch_constant_int_op_binder {
return false;
}
};
struct torch_constant_float_op_binder {
double *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
torch_constant_float_op_binder(double *bv) : bind_value(bv) {}
bool match(Operation *op) {
if (auto constantFloat = dyn_cast<Torch::ConstantFloatOp>(op)) {
*bind_value = constantFloat.value().convertToDouble();
return true;
}
return false;
}
};
} // namespace detail
/// Matches the integer stored in a `torch.constant.int`.
@ -53,6 +68,12 @@ m_TorchConstantInt(int64_t *bind_value) {
return detail::torch_constant_int_op_binder(bind_value);
}
/// Matches the float value stored in a `torch.constant.float`.
inline detail::torch_constant_float_op_binder
m_TorchConstantFloat(double *bind_value) {
return detail::torch_constant_float_op_binder(bind_value);
}
namespace detail {
/// Matches the bool stored in a `torch.constant.bool`.
struct torch_constant_bool_op_binder {


@ -275,7 +275,27 @@ static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
}
// Helper function to get the padded tensor given the padding int values for
// the low and high ends of each dimension, and the value to pad with.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &lowPaddingInts,
SmallVectorImpl<int64_t> &highPaddingInts,
Value pad) {
Location loc = op->getLoc();
Type rankedTensorType = linalg::PadTensorOp::inferResultType(
input.getType().cast<RankedTensorType>(), lowPaddingInts,
highPaddingInts);
SmallVector<OpFoldResult> lowPaddings =
getAsOpFoldResult(b, loc, lowPaddingInts);
SmallVector<OpFoldResult> highPaddings =
getAsOpFoldResult(b, loc, highPaddingInts);
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
/*packing=*/false, loc, b);
return paddedInput;
}
// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same,
// and that zero padding is required.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &paddingInts) {
assert(input.getType().isa<RankedTensorType>() &&
@ -284,13 +304,7 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
Value c0 = b.create<arith::ConstantOp>(
loc,
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
ranked4DTensorType, input, c0, /*low=*/paddings, /*high=*/paddings,
/*packing=*/false, loc, b);
return paddedInput;
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
}
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
@ -2685,6 +2699,57 @@ public:
};
} // namespace
namespace {
class ConvertAtenConstantPadNdOp
: public OpConversionPattern<AtenConstantPadNdOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenConstantPadNdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value self = adaptor.self();
auto type = self.getType().cast<RankedTensorType>();
int64_t rank = type.getRank();
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
SmallVector<int64_t> padInts;
if (!matchPattern(op.pad(), m_TorchConstantIntList(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");
uint64_t padRank = padInts.size() / 2;
if (padRank * 2 != padInts.size())
return rewriter.notifyMatchFailure(op, "pad range size is not even");
if (rank < 0 || padRank > (uint64_t)rank)
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
// Initialize low/high paddings with the dims that should not be padded.
SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
// Add the requested padding - note op.pad() is highest dim first ordered
// pairs of low,high.
for (uint64_t i = padRank; i > 0; --i) {
lowPadding.push_back(padInts[i * 2 - 2]);
highPadding.push_back(padInts[i * 2 - 1]);
}
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.value(), elementType);
Value paddedInput = getPaddedTensor(op, rewriter, self, lowPadding,
highPadding, castedValue);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
return success();
}
};
} // namespace
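The loop above walks op.pad() from its last (innermost-dimension) pair to its first, so the resulting lowPadding/highPadding vectors follow the tensor's dimension order expected by linalg::PadTensorOp. A small Python mirror of that index arithmetic, for illustration only (split_pad_list is a hypothetical name, not part of the patch):

# Illustrative mirror of the lowPadding/highPadding computation above.
def split_pad_list(rank, pad_ints):
    pad_rank = len(pad_ints) // 2
    low = [0] * (rank - pad_rank)
    high = [0] * (rank - pad_rank)
    # pad_ints holds (low, high) pairs ordered from the last dimension outward.
    for i in range(pad_rank, 0, -1):
        low.append(pad_ints[i * 2 - 2])
        high.append(pad_ints[i * 2 - 1])
    return low, high

# Rank-4 input, pad = (0, 1, 2, 3):
print(split_pad_list(4, [0, 1, 2, 3]))  # ([0, 0, 2, 0], [0, 0, 3, 1])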
namespace {
class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
@ -4225,6 +4290,8 @@ public:
patterns.add<ConvertAtenViewOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenConstantPadNdOp>();
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
target.addIllegalOp<AtenTransposeIntOp>();


@ -13,8 +13,10 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::torch;
@ -653,6 +655,36 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
[](int64_t a, int64_t b) { return a == b; });
}
//===----------------------------------------------------------------------===//
// AtenEqFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
double lhs, rhs;
if (!matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)) ||
!matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)))
return nullptr;
return getI1IntegerAttr(getContext(), lhs == rhs);
}
//===----------------------------------------------------------------------===//
// AtenEqStrOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
if (getOperand(0) == getOperand(1))
return getI1IntegerAttr(getContext(), true);
auto aStr = a().getDefiningOp<ConstantStrOp>();
auto bStr = b().getDefiningOp<ConstantStrOp>();
// Compare the underlying string values rather than the defining ops, so
// that two distinct constants holding the same string still fold to true.
if (aStr && bStr)
return getI1IntegerAttr(getContext(), aStr.value() == bStr.value());
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//
@ -1005,6 +1037,20 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// PrimUninitializedOp
//===----------------------------------------------------------------------===//
void PrimUninitializedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
if (!op.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
});
}
//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//
@ -1129,6 +1175,17 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenNegIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI64IntegerAttr(getContext(), -c);
return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimDtypeOp
//===----------------------------------------------------------------------===//


@ -490,6 +490,8 @@ public:
return visitAtenNllLossForwardOp(nllForwardOp, operands);
} else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
} else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
return visitAtenConstantPadNdOp(constantPadNdOp, operands);
}
// Otherwise, this is an unknown operation. Just mark all results as
@ -513,6 +515,9 @@ private:
ChangeResult
visitAtenMaxPool2dOp(AtenMaxPool2dOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenConstantPadNdOp(AtenConstantPadNdOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitAtenAdaptiveAvgPool2dOp(
AtenAdaptiveAvgPool2dOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@ -920,18 +925,18 @@ ChangeResult TypeAnalyzer::visitAtenConv2dOp(
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
auto &ifm = operands[0]->getValue();
auto &input = operands[0]->getValue();
auto &weights = operands[1]->getValue();
if (weights.hasSizes && ifm.hasSizes)
if (weights.hasSizes && input.hasSizes)
knowledge.sizes = computeOpWithKernelOutputShape(
op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
op, input, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
else
knowledge.sizes.resize(4, kUnknownSize);
// Running some experiments in PyTorch, the bias doesn't seem to
// contribute to the final element type.
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
{&ifm, &weights});
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), {&input, &weights});
return getLatticeElement(op->getResult(0)).join(knowledge);
}
@ -940,19 +945,45 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
auto &ifm = operands[0]->getValue();
auto &input = operands[0]->getValue();
SmallVector<int64_t, 2> kernelSize;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
kernelSize = SmallVector<int64_t, 2>{kUnknownSize, kUnknownSize};
if (ifm.hasSizes)
if (input.hasSizes)
knowledge.sizes = computeOpWithKernelOutputShape(
op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
op, input, input.sizes[1], kernelSize[0], kernelSize[1]);
else
knowledge.sizes.resize(4, kUnknownSize);
knowledge.dtype = operands[0]->getValue().dtype;
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenConstantPadNdOp(
AtenConstantPadNdOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto &input = operands[0]->getValue();
if (input.hasSizes) {
knowledge.hasSizes = true;
SmallVector<int64_t> padInts;
if (matchPattern(op.pad(), m_TorchConstantIntList(padInts))) {
knowledge.sizes = input.sizes;
uint64_t padRank = padInts.size() / 2;
uint64_t padOffset = knowledge.sizes.size() - padRank;
// op.pad() is highest dim first ordered pairs of low,high.
for (uint64_t i = padRank, r = padOffset; i > 0; --i, ++r) {
if (knowledge.sizes[r] != kUnknownSize)
knowledge.sizes[r] += padInts[i * 2 - 2] + padInts[i * 2 - 1];
}
} else
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
}
knowledge.dtype = operands[0]->getValue().dtype;
return getLatticeElement(op->getResult(0)).join(knowledge);
}
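For illustration only (not part of the diff), the same shape-refinement rule expressed in Python; refine_pad_shape and the kUnknownSize constant here are stand-ins for the C++ logic above:

kUnknownSize = -1

def refine_pad_shape(sizes, pad_ints):
    # Mirrors visitAtenConstantPadNdOp: grow each padded, known dim by low + high.
    sizes = list(sizes)
    pad_rank = len(pad_ints) // 2
    offset = len(sizes) - pad_rank
    for i, r in zip(range(pad_rank, 0, -1), range(offset, len(sizes))):
        if sizes[r] != kUnknownSize:
            sizes[r] += pad_ints[i * 2 - 2] + pad_ints[i * 2 - 1]
    return sizes

print(refine_pad_shape([1, 1, 20, 20], [0, 1, 2, 3]))  # [1, 1, 25, 21]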
ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
AtenAdaptiveAvgPool2dOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {


@ -414,7 +414,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)")
emit("prim::RaiseException : (str) -> ()")
emit("prim::Uninitialized : () -> (Any)")
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True)
emit("prim::unchecked_cast : (t) -> (t)",
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()")
@ -540,6 +540,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
@ -619,6 +620,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
# Str ops.
emit("aten::add.str : (str, str) -> (str)")
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
emit("aten::str : (t) -> (str)")
emit("aten::format : (...) -> (str)")
emit("aten::join : (str, str[]) -> (str)")
@ -640,11 +642,13 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)")
emit("aten::mul.float : (float, float) -> (float)")
emit("aten::neg.float : (float) -> (float)")
emit("aten::lt.float_int : (float, int) -> (bool)")
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::__and__.bool : (bool, bool) -> (bool)")
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)


@ -249,6 +249,55 @@ func @torch.aten.ge.int$same_value() -> !torch.bool {
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.eq.float$different_value() -> !torch.bool {
%float4 = torch.constant.float 4.0
%float5 = torch.constant.float 5.0
%2 = torch.aten.eq.float %float4, %float5 : !torch.float, !torch.float -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.float$same_value() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.eq.float$same_value() -> !torch.bool {
%float4 = torch.constant.float 4.0
%float4_0 = torch.constant.float 4.0
%2 = torch.aten.eq.float %float4, %float4_0 : !torch.float, !torch.float -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.str$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.eq.str$different_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str5 = torch.constant.str "5"
%2 = torch.aten.eq.str %str4, %str5 : !torch.str, !torch.str -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.str$same_operand(
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
// CHECK-NEXT: return %[[F]] : !torch.bool
func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.str$same_value() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.eq.str$same_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str4_0 = torch.constant.str "4"
%2 = torch.aten.eq.str %str4, %str4_0 : !torch.str, !torch.str -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.__not__
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool