mirror of https://github.com/llvm/torch-mlir
Add support for constant_pad_nd
Note that to enable folding of the code coming from an example like the ConstantPad2dStaticModule e2e test, support for other operations had to be added/improved: - aten::neg.int - aten::eq.float - aten::eq.str - prim::Uninitializedpull/519/head snapshot-20220111.200
parent
35cf8d18f7
commit
077e55d756
|
@ -1,4 +1,5 @@
|
|||
*.swp
|
||||
.cache/
|
||||
.vscode
|
||||
.env
|
||||
*.code-workspace
|
||||
|
|
|
@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -38,7 +37,6 @@ def MmModule_chained(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BmmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -57,10 +55,8 @@ class BmmModule(torch.nn.Module):
|
|||
def BmmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# A subgraph with multiple mm ops.
|
||||
class MmDagModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -80,10 +76,8 @@ class MmDagModule(torch.nn.Module):
|
|||
def MmDagModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MmTanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -100,8 +94,6 @@ class MmTanhModule(torch.nn.Module):
|
|||
def matmul(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MmTanhModule())
|
||||
def MmTanhModule_basic(module, tu: TestUtils):
|
||||
|
@ -109,7 +101,6 @@ def MmTanhModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AddmmModuleFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -196,7 +187,6 @@ def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -217,7 +207,6 @@ def FlattenStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenRank0Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -238,7 +227,6 @@ def FlattenRank0Module_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenDynamicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -259,7 +247,6 @@ def FlattenDynamicModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxPool2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -276,14 +263,86 @@ class MaxPool2dModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
|
||||
class ConstantPad2dStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float('inf'))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 20, 20], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.pad2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConstantPad2dStaticModule())
|
||||
def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ConstantPadNdModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConstantPadNdModule())
|
||||
def ConstantPadNdModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||
|
||||
|
||||
class ConstantPadNdStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 20, 20, 4, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConstantPadNdStaticModule())
|
||||
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||
|
||||
class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 20, 20, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConstantPadNdPartialStaticModule())
|
||||
def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TransposeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -296,13 +355,13 @@ class TransposeIntModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.transpose(x, 0, 1)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TransposeIntModule())
|
||||
def TransposeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class PermuteModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -333,13 +392,12 @@ class TransposeIntNegDimsModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
|
||||
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class PermuteNegativeIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -353,11 +411,12 @@ class PermuteNegativeIndexModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return x.permute(0, -1, 1)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
|
||||
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorsConcatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -379,7 +438,6 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GatherModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -422,7 +480,6 @@ def AddSizeIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AddSizeIntNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -505,7 +562,6 @@ def _SoftmaxModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SoftmaxIntNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -527,7 +583,6 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1778,6 +1778,22 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
|
|||
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
|
||||
}
|
||||
|
||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$pad,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $pad `,` $value attr-dict `:` type($self) `,` type($pad) `,` type($value) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
@ -2915,6 +2931,22 @@ def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
|
|||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenEqStrOp : Torch_Op<"aten.eq.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::eq.str : (str, str) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_StringType:$a,
|
||||
Torch_StringType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenStrOp : Torch_Op<"aten.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -3175,6 +3207,21 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::neg.int : (int) -> (int)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenLogIntOp : Torch_Op<"aten.log.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -3248,6 +3295,22 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
|
|||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::eq.float : (float, float) -> (bool)`";
|
||||
let arguments = (ins
|
||||
Torch_FloatType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -185,6 +185,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
|||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = " attr-dict `:` type($result)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
|
||||
|
|
|
@ -45,6 +45,21 @@ struct torch_constant_int_op_binder {
|
|||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct torch_constant_float_op_binder {
|
||||
double *bind_value;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bv if match succeeds.
|
||||
torch_constant_float_op_binder(double *bv) : bind_value(bv) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
if (auto constantFloat = dyn_cast<Torch::ConstantFloatOp>(op)) {
|
||||
*bind_value = constantFloat.value().convertToDouble();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Matches the integer stored in a `torch.constant.bool`.
|
||||
|
@ -53,6 +68,12 @@ m_TorchConstantInt(int64_t *bind_value) {
|
|||
return detail::torch_constant_int_op_binder(bind_value);
|
||||
}
|
||||
|
||||
/// Matches the float value stored in a `torch.constant.float`.
|
||||
inline detail::torch_constant_float_op_binder
|
||||
m_TorchConstantFloat(double *bind_value) {
|
||||
return detail::torch_constant_float_op_binder(bind_value);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
/// Matches the bool stored in a `torch.constant.bool`.
|
||||
struct torch_constant_bool_op_binder {
|
||||
|
|
|
@ -275,7 +275,27 @@ static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
// Helper function to get the padding tensor given the padding int values.
|
||||
// It's assumed that the padding on the low end and high end are the same.
|
||||
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<int64_t> &lowPaddingInts,
|
||||
SmallVectorImpl<int64_t> &highPaddingInts,
|
||||
Value pad) {
|
||||
Location loc = op->getLoc();
|
||||
Type rankedTensorType = linalg::PadTensorOp::inferResultType(
|
||||
input.getType().cast<RankedTensorType>(), lowPaddingInts,
|
||||
highPaddingInts);
|
||||
SmallVector<OpFoldResult> lowPaddings =
|
||||
getAsOpFoldResult(b, loc, lowPaddingInts);
|
||||
SmallVector<OpFoldResult> highPaddings =
|
||||
getAsOpFoldResult(b, loc, highPaddingInts);
|
||||
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
|
||||
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
|
||||
/*packing=*/false, loc, b);
|
||||
return paddedInput;
|
||||
}
|
||||
|
||||
// Helper function to get the padding tensor given the padding int values.
|
||||
// It's assumed that the padding on the low end and high end are the same,
|
||||
// and that zero padding is required.
|
||||
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<int64_t> &paddingInts) {
|
||||
assert(input.getType().isa<RankedTensorType>() &&
|
||||
|
@ -284,13 +304,7 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
|||
Value c0 = b.create<arith::ConstantOp>(
|
||||
loc,
|
||||
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
|
||||
SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
|
||||
Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
|
||||
input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
|
||||
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
|
||||
ranked4DTensorType, input, c0, /*low=*/paddings, /*high=*/paddings,
|
||||
/*packing=*/false, loc, b);
|
||||
return paddedInput;
|
||||
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
||||
}
|
||||
|
||||
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
|
||||
|
@ -2685,6 +2699,57 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenConstantPadNdOp
|
||||
: public OpConversionPattern<AtenConstantPadNdOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenConstantPadNdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
Value self = adaptor.self();
|
||||
auto type = self.getType().cast<RankedTensorType>();
|
||||
int64_t rank = type.getRank();
|
||||
|
||||
// Pattern match against the op's original operands, because otherwise we
|
||||
// will get the lowered version of the operands which is harder to pattern
|
||||
// match.
|
||||
SmallVector<int64_t> padInts;
|
||||
if (!matchPattern(op.pad(), m_TorchConstantIntList(padInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant int pad ranges");
|
||||
uint64_t padRank = padInts.size() / 2;
|
||||
if (padRank * 2 != padInts.size())
|
||||
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
||||
if (rank < 0 || padRank > (uint64_t)rank)
|
||||
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
||||
|
||||
// Initialize low/high paddings with the dims that should not be padded.
|
||||
SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
|
||||
SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
|
||||
// Add the requested padding - note op.pad() is highest dim first ordered
|
||||
// pairs of low,high.
|
||||
for (uint64_t i = padRank; i > 0; --i) {
|
||||
lowPadding.push_back(padInts[i * 2 - 2]);
|
||||
highPadding.push_back(padInts[i * 2 - 1]);
|
||||
}
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
|
||||
Value castedValue =
|
||||
convertScalarToDtype(rewriter, loc, adaptor.value(), elementType);
|
||||
Value paddedInput = getPaddedTensor(op, rewriter, self, lowPadding,
|
||||
highPadding, castedValue);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenFlattenUsingIntsOp
|
||||
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
|
||||
|
@ -4225,6 +4290,8 @@ public:
|
|||
patterns.add<ConvertAtenViewOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenConstantPadNdOp>();
|
||||
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSumOp>();
|
||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -653,6 +655,36 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
[](int64_t a, int64_t b) { return a == b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenEqFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
double lhs, rhs;
|
||||
|
||||
if (!matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)) ||
|
||||
!matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)))
|
||||
return nullptr;
|
||||
|
||||
return getI1IntegerAttr(getContext(), lhs == rhs);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenEqStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return getI1IntegerAttr(getContext(), true);
|
||||
|
||||
auto aStr = a().getDefiningOp<ConstantStrOp>();
|
||||
auto bStr = b().getDefiningOp<ConstantStrOp>();
|
||||
|
||||
if (aStr && bStr)
|
||||
return getI1IntegerAttr(getContext(), aStr == bStr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1005,6 +1037,20 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimUninitializedOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PrimUninitializedOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
|
||||
if (!op.use_empty())
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimTupleUnpackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1129,6 +1175,17 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenNegIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getI64IntegerAttr(getContext(), -c);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -490,6 +490,8 @@ public:
|
|||
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
||||
} else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
|
||||
return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
|
||||
} else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
|
||||
return visitAtenConstantPadNdOp(constantPadNdOp, operands);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -513,6 +515,9 @@ private:
|
|||
ChangeResult
|
||||
visitAtenMaxPool2dOp(AtenMaxPool2dOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenConstantPadNdOp(AtenConstantPadNdOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult visitAtenAdaptiveAvgPool2dOp(
|
||||
AtenAdaptiveAvgPool2dOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
@ -920,18 +925,18 @@ ChangeResult TypeAnalyzer::visitAtenConv2dOp(
|
|||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
auto &ifm = operands[0]->getValue();
|
||||
auto &input = operands[0]->getValue();
|
||||
auto &weights = operands[1]->getValue();
|
||||
if (weights.hasSizes && ifm.hasSizes)
|
||||
if (weights.hasSizes && input.hasSizes)
|
||||
knowledge.sizes = computeOpWithKernelOutputShape(
|
||||
op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
|
||||
op, input, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
|
||||
else
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
|
||||
// Running some experiments in PyTorch, the bias doesn't seem to
|
||||
// contribute to the final element type.
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
|
||||
{&ifm, &weights});
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), {&input, &weights});
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
|
@ -940,19 +945,45 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
|
|||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
auto &ifm = operands[0]->getValue();
|
||||
auto &input = operands[0]->getValue();
|
||||
SmallVector<int64_t, 2> kernelSize;
|
||||
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
|
||||
kernelSize = SmallVector<int64_t, 2>{kUnknownSize, kUnknownSize};
|
||||
if (ifm.hasSizes)
|
||||
if (input.hasSizes)
|
||||
knowledge.sizes = computeOpWithKernelOutputShape(
|
||||
op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
|
||||
op, input, input.sizes[1], kernelSize[0], kernelSize[1]);
|
||||
else
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenConstantPadNdOp(
|
||||
AtenConstantPadNdOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
auto &input = operands[0]->getValue();
|
||||
if (input.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
SmallVector<int64_t> padInts;
|
||||
if (matchPattern(op.pad(), m_TorchConstantIntList(padInts))) {
|
||||
knowledge.sizes = input.sizes;
|
||||
uint64_t padRank = padInts.size() / 2;
|
||||
uint64_t padOffset = knowledge.sizes.size() - padRank;
|
||||
// op.pad() is highest dim first ordered pairs of low,high.
|
||||
for (uint64_t i = padRank, r = padOffset; i > 0; --i, ++r) {
|
||||
if (knowledge.sizes[r] != kUnknownSize)
|
||||
knowledge.sizes[r] += padInts[i * 2 - 2] + padInts[i * 2 - 1];
|
||||
}
|
||||
} else
|
||||
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
|
||||
}
|
||||
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
|
||||
AtenAdaptiveAvgPool2dOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
|
|
|
@ -414,7 +414,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("prim::max.self_int : (int[]) -> (int)")
|
||||
emit("prim::max.int : (int, int) -> (int)")
|
||||
emit("prim::RaiseException : (str) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)")
|
||||
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True)
|
||||
emit("prim::unchecked_cast : (t) -> (t)",
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Print : (...) -> ()")
|
||||
|
@ -540,6 +540,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
|
@ -619,6 +620,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
|
||||
# Str ops.
|
||||
emit("aten::add.str : (str, str) -> (str)")
|
||||
emit("aten::eq.str : (str, str) -> (bool)", has_folder=True)
|
||||
emit("aten::str : (t) -> (str)")
|
||||
emit("aten::format : (...) -> (str)")
|
||||
emit("aten::join : (str, str[]) -> (str)")
|
||||
|
@ -640,11 +642,13 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||
emit("aten::log.int : (int) -> (float)")
|
||||
emit("aten::add.float_int : (float, int) -> (float)")
|
||||
emit("aten::mul.float : (float, float) -> (float)")
|
||||
emit("aten::neg.float : (float) -> (float)")
|
||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
||||
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -249,6 +249,55 @@ func @torch.aten.ge.int$same_value() -> !torch.bool {
|
|||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
func @torch.aten.eq.float$different_value() -> !torch.bool {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float5 = torch.constant.float 5.0
|
||||
%2 = torch.aten.eq.float %float4, %float5 : !torch.float, !torch.float -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.float$same_value() -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
func @torch.aten.eq.float$same_value() -> !torch.bool {
|
||||
%float4 = torch.constant.float 4.0
|
||||
%float4_0 = torch.constant.float 4.0
|
||||
%2 = torch.aten.eq.float %float4, %float4_0 : !torch.float, !torch.float -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.str$different_value() -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
func @torch.aten.eq.str$different_value() -> !torch.bool {
|
||||
%str4 = torch.constant.str "4"
|
||||
%str5 = torch.constant.str "5"
|
||||
%2 = torch.aten.eq.str %str4, %str5 : !torch.str, !torch.str -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.str$same_operand(
|
||||
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
|
||||
// CHECK-NEXT: return %[[F]] : !torch.bool
|
||||
func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.str$same_value() -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
func @torch.aten.eq.str$same_value() -> !torch.bool {
|
||||
%str4 = torch.constant.str "4"
|
||||
%str4_0 = torch.constant.str "4"
|
||||
%2 = torch.aten.eq.str %str4, %str4_0 : !torch.str, !torch.str -> !torch.bool
|
||||
return %2 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.__not__
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
|
|
Loading…
Reference in New Issue