From 22aeb967c5bc09d68ab93638214ff9c3b2a9be69 Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Thu, 21 Oct 2021 00:52:53 -0400
Subject: [PATCH] Add ones

---
 e2e_testing/torchscript/basic.py                   | 32 ++++++++++-
 .../TorchToLinalg/TorchToLinalg.cpp                | 56 +++++++++++++++++++
 2 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index d1c49acbe..588d10f49 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -396,4 +396,34 @@ class BroadcastToModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: BroadcastToModule())
 def BroadcastToModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 1, 1))
\ No newline at end of file
+    module.forward(tu.rand(3, 1, 1))
+
+class OnesModuleInt(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.ones(3, 4, dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: OnesModuleInt())
+def OnesModuleInt_basic(module, tu: TestUtils):
+    module.forward()
+
+class OnesModuleFloat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.ones(3, 4, dtype=torch.float32)
+
+@register_test_case(module_factory=lambda: OnesModuleFloat())
+def OnesModuleFloat_basic(module, tu: TestUtils):
+    module.forward()
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index f40c72e61..c60075cc6 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -2233,6 +2233,60 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOnesOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    AtenOnesOp::Adaptor adaptor(operands);
+    Location loc = op.getLoc();
+
+    // We ignore device, but add simple asserts for unimplemented kwargs
+    if (!adaptor.layout().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(op,
+                                         "only default layout is supported");
+    bool pinMemory;
+    if (!adaptor.pin_memory().getType().isa<Torch::NoneType>() &&
+        !matchPattern(adaptor.pin_memory(), m_TorchConstantBool(&pinMemory)))
+      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
+
+    SmallVector<Value> size, sizeIndex;
+    if (!getListConstructElements(adaptor.size(), size)) {
+      return rewriter.notifyMatchFailure(
+          op, "size must be created by ListConstruct");
+    }
+    size = getTypeConvertedValues(rewriter, loc, getTypeConverter(), size);
+    for (size_t i = 0; i < size.size(); i++)
+      sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
+
+    RankedTensorType newResultType =
+        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+    Type outElementType = newResultType.getElementType();
+
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, outElementType,
+        (outElementType.isa<mlir::FloatType>()
+             ? rewriter.getFloatAttr(outElementType, 1).cast<mlir::Attribute>()
+             : rewriter.getIntegerAttr(outElementType, 1)
+                   .cast<mlir::Attribute>()));
+    Value outTensor = rewriter
+                          .create<linalg::InitTensorOp>(
+                              loc, sizeIndex, newResultType.getElementType())
+                          .getResult();
+    Value fillOp =
+        rewriter.create<linalg::FillOp>(loc, one, outTensor).getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -2302,6 +2356,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp<AtenBroadcastToOp>();
     patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertAtenOnesOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
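
Note (a minimal sketch, not part of the patch): the ConvertAtenOnesOp pattern above lowers torch.ones in two steps, first creating an uninitialized tensor of the requested size with linalg::InitTensorOp and then filling it with the constant 1 via linalg::FillOp, before casting to the converted result type. The plain PyTorch snippet below mirrors that allocate-then-fill recipe for the two dtypes exercised by the new e2e tests; the helper name ones_by_fill is made up for illustration and does not appear in the patch.

import torch

# Illustration only: allocate then fill, mirroring the structure of the
# linalg lowering (torch.empty stands in for linalg.init_tensor,
# Tensor.fill_ for linalg.fill).
def ones_by_fill(size, dtype):
    out = torch.empty(size, dtype=dtype)
    return out.fill_(1)

assert torch.equal(ones_by_fill([3, 4], torch.int64),
                   torch.ones(3, 4, dtype=torch.int64))
assert torch.equal(ones_by_fill([3, 4], torch.float32),
                   torch.ones(3, 4, dtype=torch.float32))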