diff --git a/.gitignore b/.gitignore
index 8d533fa5b..0deb56bb5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.swp
+.cache/
 .vscode
 .env
 *.code-workspace
diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index d91c53de6..9e18ba704 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
-
 class MmModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -38,7 +37,6 @@ def MmModule_chained(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class BmmModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -57,10 +55,8 @@ class BmmModule(torch.nn.Module):
 def BmmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
 
-
 # ==============================================================================
-
 # A subgraph with multiple mm ops.
 class MmDagModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -80,10 +76,8 @@ class MmDagModule(torch.nn.Module):
 def MmDagModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
 
-
 # ==============================================================================
-
 
 class MmTanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -100,8 +94,6 @@ class MmTanhModule(torch.nn.Module):
     def matmul(self, lhs, rhs):
         return torch.mm(lhs, rhs)
 
-# ==============================================================================
-
 
 @register_test_case(module_factory=lambda: MmTanhModule())
 def MmTanhModule_basic(module, tu: TestUtils):
@@ -109,7 +101,6 @@ def MmTanhModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AddmmModuleFloat(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -196,7 +187,6 @@ def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -217,7 +207,6 @@ def FlattenStaticModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenRank0Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -238,7 +227,6 @@ def FlattenRank0Module_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class FlattenDynamicModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -259,7 +247,6 @@ def FlattenDynamicModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class MaxPool2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -276,14 +263,86 @@ class MaxPool2dModule(torch.nn.Module):
     def forward(self, x):
         return self.mp2d(x)
 
 
-# ==============================================================================
-
 @register_test_case(module_factory=lambda: MaxPool2dModule())
 def MaxPool2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20) - 0.5)
 
 
+class ConstantPad2dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad2d = torch.nn.ConstantPad2d((0, 1, 2, 3), -float('inf'))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.pad2d(x)
+
+
+@register_test_case(module_factory=lambda: ConstantPad2dStaticModule())
+def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20) - 0.5)
+
+# ==============================================================================
+
+class ConstantPadNdModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdModule())
+def ConstantPadNdModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+
+class ConstantPadNdStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20, 4, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdStaticModule())
+def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+class ConstantPadNdPartialStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 20, 20, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
+
+
+@register_test_case(module_factory=lambda: ConstantPadNdPartialStaticModule())
+def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
+
+# ==============================================================================
+
 class TransposeIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -296,13 +355,13 @@ class TransposeIntModule(torch.nn.Module):
     def forward(self, x):
         return torch.transpose(x, 0, 1)
 
-# ==============================================================================
-
 @register_test_case(module_factory=lambda: TransposeIntModule())
 def TransposeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
 
+# ==============================================================================
+
 class PermuteModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -333,13 +392,12 @@ class TransposeIntNegDimsModule(torch.nn.Module):
     def forward(self, x):
         return torch.transpose(x, -1, -2)
 
-# ==============================================================================
-
 @register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
 def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
+# ==============================================================================
 
 class PermuteNegativeIndexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -353,11 +411,12 @@ class PermuteNegativeIndexModule(torch.nn.Module):
     def forward(self, x):
         return x.permute(0, -1, 1)
 
-# ==============================================================================
-
 @register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
 def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
+
+# ==============================================================================
+
 class TensorsConcatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -379,7 +438,6 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class GatherModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -422,7 +480,6 @@ def AddSizeIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AddSizeIntNegDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -505,7 +562,6 @@ def _SoftmaxModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class SoftmaxIntNegDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -527,7 +583,6 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class SoftmaxIntArgTypeF64Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 8fc76aa24..2dbc9e18b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1778,6 +1778,22 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
 }
 
+def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchIntListType:$pad,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $pad `,` $value attr-dict `:` type($self) `,` type($pad) `,` type($value) `->` type($result)";
+}
+
 def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
     AllowsTypeRefinement
   ]> {
@@ -2915,6 +2931,22 @@ def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_AtenEqStrOp : Torch_Op<"aten.eq.str", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::eq.str : (str, str) -> (bool)`";
+  let arguments = (ins
+    Torch_StringType:$a,
+    Torch_StringType:$b
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenStrOp : Torch_Op<"aten.str", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -3175,6 +3207,21 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::neg.int : (int) -> (int)`";
+  let arguments = (ins
+    Torch_IntType:$a
+  );
+  let results = (outs
+    Torch_IntType:$result
+  );
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenLogIntOp : Torch_Op<"aten.log.int", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -3248,6 +3295,22 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::eq.float : (float, float) -> (bool)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_FloatType:$b
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td
index d7dc8d80a..a83e26379 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td
@@ -185,6 +185,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
     AnyTorchType:$result
   );
   let assemblyFormat = " attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
index 1c556a11a..3580d0baa 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -45,6 +45,21 @@ struct torch_constant_int_op_binder {
     return false;
   }
 };
+
+struct torch_constant_float_op_binder {
+  double *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  torch_constant_float_op_binder(double *bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    if (auto constantFloat = dyn_cast<Torch::ConstantFloatOp>(op)) {
+      *bind_value = constantFloat.value().convertToDouble();
+      return true;
+    }
+    return false;
+  }
+};
 } // namespace detail
 
 /// Matches the integer stored in a `torch.constant.bool`.
@@ -53,6 +68,12 @@ m_TorchConstantInt(int64_t *bind_value) {
   return detail::torch_constant_int_op_binder(bind_value);
 }
 
+/// Matches the float value stored in a `torch.constant.float`.
+inline detail::torch_constant_float_op_binder
+m_TorchConstantFloat(double *bind_value) {
+  return detail::torch_constant_float_op_binder(bind_value);
+}
+
 namespace detail {
 /// Matches the bool stored in a `torch.constant.bool`.
 struct torch_constant_bool_op_binder {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index a44942616..1405d13c7 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -275,7 +275,27 @@ static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
 }
 
 // Helper function to get the padding tensor given the padding int values.
-// It's assumed that the padding on the low end and high end are the same.
+static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
+                             SmallVectorImpl<int64_t> &lowPaddingInts,
+                             SmallVectorImpl<int64_t> &highPaddingInts,
+                             Value pad) {
+  Location loc = op->getLoc();
+  Type rankedTensorType = linalg::PadTensorOp::inferResultType(
+      input.getType().cast<RankedTensorType>(), lowPaddingInts,
+      highPaddingInts);
+  SmallVector<OpFoldResult> lowPaddings =
+      getAsOpFoldResult(b, loc, lowPaddingInts);
+  SmallVector<OpFoldResult> highPaddings =
+      getAsOpFoldResult(b, loc, highPaddingInts);
+  Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
+      rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
+      /*packing=*/false, loc, b);
+  return paddedInput;
+}
+
+// Helper function to get the padding tensor given the padding int values.
+// It's assumed that the padding on the low end and high end are the same,
+// and that zero padding is required.
 static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                              SmallVectorImpl<int64_t> &paddingInts) {
   assert(input.getType().isa<RankedTensorType>() &&
@@ -284,13 +304,7 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
   Location loc = op->getLoc();
   Value c0 = b.create<arith::ConstantOp>(
       loc, b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
-  SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
-  Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
-      input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
-  Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
-      ranked4DTensorType, input, c0, /*low=*/paddings, /*high=*/paddings,
-      /*packing=*/false, loc, b);
-  return paddedInput;
+  return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
 }
 
 static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
@@ -2685,6 +2699,57 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenConstantPadNdOp
+    : public OpConversionPattern<AtenConstantPadNdOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenConstantPadNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value self = adaptor.self();
+    auto type = self.getType().cast<RankedTensorType>();
+    int64_t rank = type.getRank();
+
+    // Pattern match against the op's original operands, because otherwise we
+    // will get the lowered version of the operands which is harder to pattern
+    // match.
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.pad(), m_TorchConstantIntList(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+    uint64_t padRank = padInts.size() / 2;
+    if (padRank * 2 != padInts.size())
+      return rewriter.notifyMatchFailure(op, "pad range size is not even");
+    if (rank < 0 || padRank > (uint64_t)rank)
+      return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
+
+    // Initialize low/high paddings with the dims that should not be padded.
+    SmallVector<int64_t> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
+    SmallVector<int64_t> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
+    // Add the requested padding - note op.pad() is highest dim first ordered
+    // pairs of low,high.
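+    // For example (illustrative values): with rank == 4 and pad = (0, 1, 2, 3),
+    // the loop below appends (2, 3) for dim 2 and then (0, 1) for dim 3,
+    // yielding lowPadding = [0, 0, 2, 0] and highPadding = [0, 0, 3, 1].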
+    for (uint64_t i = padRank; i > 0; --i) {
+      lowPadding.push_back(padInts[i * 2 - 2]);
+      highPadding.push_back(padInts[i * 2 - 1]);
+    }
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type elementType = newResultType.cast<RankedTensorType>().getElementType();
+    Value castedValue =
+        convertScalarToDtype(rewriter, loc, adaptor.value(), elementType);
+    Value paddedInput = getPaddedTensor(op, rewriter, self, lowPadding,
+                                        highPadding, castedValue);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenFlattenUsingIntsOp
     : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -4225,6 +4290,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenConstantPadNdOp>();
+    patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
    target.addIllegalOp();
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 51f90159f..f43eaefe4 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -13,8 +13,10 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -653,6 +655,36 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
                                  [](int64_t a, int64_t b) { return a == b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenEqFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
+  double lhs, rhs;
+
+  if (!matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)) ||
+      !matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)))
+    return nullptr;
+
+  return getI1IntegerAttr(getContext(), lhs == rhs);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenEqStrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
+  if (getOperand(0) == getOperand(1))
+    return getI1IntegerAttr(getContext(), true);
+
+  auto aStr = a().getDefiningOp<ConstantStrOp>();
+  auto bStr = b().getDefiningOp<ConstantStrOp>();
+
+  if (aStr && bStr)
+    return getI1IntegerAttr(getContext(), aStr == bStr);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLtIntOp
 //===----------------------------------------------------------------------===//
@@ -1005,6 +1037,20 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// PrimUninitializedOp
+//===----------------------------------------------------------------------===//
+
+void PrimUninitializedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
+    if (!op.use_empty())
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // PrimTupleUnpackOp
 //===----------------------------------------------------------------------===//
@@ -1129,6 +1175,17 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenNegIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
+  int64_t c;
+  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
+    return getI64IntegerAttr(getContext(), -c);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // PrimDtypeOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index c4e4af2fe..1cb649012 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -490,6 +490,8 @@ public:
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
     } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
+    } else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
+      return visitAtenConstantPadNdOp(constantPadNdOp, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -513,6 +515,9 @@ private:
   ChangeResult
   visitAtenMaxPool2dOp(AtenMaxPool2dOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenConstantPadNdOp(AtenConstantPadNdOp op,
+                           ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitAtenAdaptiveAvgPool2dOp(
       AtenAdaptiveAvgPool2dOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -920,18 +925,18 @@ ChangeResult TypeAnalyzer::visitAtenConv2dOp(
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.hasSizes = true;
-  auto &ifm = operands[0]->getValue();
+  auto &input = operands[0]->getValue();
   auto &weights = operands[1]->getValue();
-  if (weights.hasSizes && ifm.hasSizes)
+  if (weights.hasSizes && input.hasSizes)
     knowledge.sizes = computeOpWithKernelOutputShape(
-        op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
+        op, input, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
   else
     knowledge.sizes.resize(4, kUnknownSize);
   // Running some experiments in PyTorch, the bias doesn't seem to
   // contribute to the final element type.
-  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
-                                                             {&ifm, &weights});
+  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
+      op->getContext(), {&input, &weights});
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
@@ -940,19 +945,45 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.hasSizes = true;
-  auto &ifm = operands[0]->getValue();
+  auto &input = operands[0]->getValue();
   SmallVector<int64_t> kernelSize;
   if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
     kernelSize = SmallVector<int64_t>{kUnknownSize, kUnknownSize};
-  if (ifm.hasSizes)
+  if (input.hasSizes)
     knowledge.sizes = computeOpWithKernelOutputShape(
-        op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
+        op, input, input.sizes[1], kernelSize[0], kernelSize[1]);
   else
     knowledge.sizes.resize(4, kUnknownSize);
   knowledge.dtype = operands[0]->getValue().dtype;
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenConstantPadNdOp(
+    AtenConstantPadNdOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto &input = operands[0]->getValue();
+  if (input.hasSizes) {
+    knowledge.hasSizes = true;
+    SmallVector<int64_t> padInts;
+    if (matchPattern(op.pad(), m_TorchConstantIntList(padInts))) {
+      knowledge.sizes = input.sizes;
+      uint64_t padRank = padInts.size() / 2;
+      uint64_t padOffset = knowledge.sizes.size() - padRank;
+      // op.pad() is highest dim first ordered pairs of low,high.
+      for (uint64_t i = padRank, r = padOffset; i > 0; --i, ++r) {
+        if (knowledge.sizes[r] != kUnknownSize)
+          knowledge.sizes[r] += padInts[i * 2 - 2] + padInts[i * 2 - 1];
+      }
+    } else
+      knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
+  }
+
+  knowledge.dtype = operands[0]->getValue().dtype;
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenAdaptiveAvgPool2dOp(
     AtenAdaptiveAvgPool2dOp op,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index b7a518c4f..845456b47 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -414,7 +414,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
     emit("prim::max.self_int : (int[]) -> (int)")
     emit("prim::max.int : (int, int) -> (int)")
     emit("prim::RaiseException : (str) -> ()")
-    emit("prim::Uninitialized : () -> (Any)")
+    emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True)
     emit("prim::unchecked_cast : (t) -> (t)",
          traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
     emit("prim::Print : (...) -> ()")
@@ -540,6 +540,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
 
     # Misc tensor ops.
+    emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
@@ -619,6 +620,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
 
     # Str ops.
emit("aten::add.str : (str, str) -> (str)") + emit("aten::eq.str : (str, str) -> (bool)", has_folder=True) emit("aten::str : (t) -> (str)") emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") @@ -640,11 +642,13 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True) + emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)") emit("aten::mul.float : (float, float) -> (float)") emit("aten::neg.float : (float) -> (float)") emit("aten::lt.float_int : (float, int) -> (bool)") + emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 2c3ab00a3..961825ba1 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -249,6 +249,55 @@ func @torch.aten.ge.int$same_value() -> !torch.bool { return %2 : !torch.bool } +// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.eq.float$different_value() -> !torch.bool { + %float4 = torch.constant.float 4.0 + %float5 = torch.constant.float 5.0 + %2 = torch.aten.eq.float %float4, %float5 : !torch.float, !torch.float -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.float$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.eq.float$same_value() -> !torch.bool { + %float4 = torch.constant.float 4.0 + %float4_0 = torch.constant.float 4.0 + %2 = torch.aten.eq.float %float4, %float4_0 : !torch.float, !torch.float -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.str$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.eq.str$different_value() -> !torch.bool { + %str4 = torch.constant.str "4" + %str5 = torch.constant.str "5" + %2 = torch.aten.eq.str %str4, %str5 : !torch.str, !torch.str -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.str$same_operand( +// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { +// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[F]] : !torch.bool +func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool { + %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.str$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.eq.str$same_value() -> !torch.bool { + %str4 = torch.constant.str "4" + %str4_0 = torch.constant.str "4" + %2 = torch.aten.eq.str %str4, %str4_0 : !torch.str, !torch.str -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func @torch.aten.__not__ // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool