diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index fb4147601..440cf9f30 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -900,7 +900,6 @@ class NumelModule(torch.nn.Module): None, ([-1, -1, -1], torch.float32, True), ]) - def forward(self, input): return torch.numel(input) @@ -918,10 +917,89 @@ class NumelZeroRankModule(torch.nn.Module): None, ([], torch.int64, True), ]) - def forward(self, input): return torch.numel(input) @register_test_case(module_factory=lambda: NumelZeroRankModule()) def NumelZeroRankModule_basic(module, tu: TestUtils): module.forward(torch.randint(10,[])) + + +class ZerosModuleInt2D(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.zeros(3, 4, dtype=torch.int64) + +@register_test_case(module_factory=lambda: ZerosModuleInt2D()) +def ZerosModuleInt2D_basic(module, tu: TestUtils): + module.forward() + + +class ZerosModuleInt3D(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.zeros(3, 4, 5, dtype=torch.int64) + +@register_test_case(module_factory=lambda: ZerosModuleInt3D()) +def ZerosModuleInt3D_basic(module, tu: TestUtils): + module.forward() + + +class ZerosModuleFloat2D(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.zeros(3, 4, dtype=torch.float32) + +@register_test_case(module_factory=lambda: ZerosModuleFloat2D()) +def ZerosModuleFloat2D_basic(module, tu: TestUtils): + module.forward() + + +class ZerosModuleFloat3D(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.zeros(3, 4, 5, dtype=torch.float32) + +@register_test_case(module_factory=lambda: ZerosModuleFloat3D()) +def ZerosModuleFloat3D_basic(module, tu: TestUtils): + module.forward() + + +class ZerosModuleFalsePinMemory(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False) + +@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory()) +def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils): + module.forward() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index d485fe7a1..42e059ea1 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -3269,24 +3269,44 @@ public: } // namespace namespace { -class ConvertAtenOnesOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +// Converts AtenOnesOp and AtenZerosOp. +struct ConvertAtenOnesZerosOp : ConversionPattern { + ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, + context) {} + LogicalResult - matchAndRewrite(AtenOnesOp op, OpAdaptor adaptor, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + if (!isa(op)) + return rewriter.notifyMatchFailure(op, + "not a supported ones or zeros op"); + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Location loc = op.getLoc(); + Location loc = op->getLoc(); + + SmallVector opArguments; + int64_t elementValue; + + if (AtenOnesOp onesOp = dyn_cast(op)) { + opArguments.insert(opArguments.end(), + {onesOp.size(), onesOp.layout(), onesOp.pin_memory()}); + elementValue = 1; + } else if (AtenZerosOp zerosOp = dyn_cast(op)) { + opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(), + zerosOp.pin_memory()}); + elementValue = 0; + } // We ignore device, but add simple asserts for unimplemented kwargs - if (!op.layout().getType().isa()) + if (!opArguments[1].getType().isa()) return rewriter.notifyMatchFailure(op, "only default layout is supported"); bool pinMemory = false; - if (!op.pin_memory().getType().isa() && - !matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory))) { + if (!opArguments[2].getType().isa() && + !matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) { return rewriter.notifyMatchFailure( op, "pin_memory must be constant bool or None"); } @@ -3294,7 +3314,7 @@ public: return rewriter.notifyMatchFailure(op, "memory pinning not supported"); SmallVector size, sizeIndex; - if (!getListConstructElements(op.size(), size)) { + if (!getListConstructElements(opArguments[0], size)) { return rewriter.notifyMatchFailure( op, "size must be created by ListConstruct"); } @@ -3303,21 +3323,24 @@ public: sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i])); RankedTensorType newResultType = - getTypeConverter()->convertType(op.getType()).cast(); + getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); Type outElementType = newResultType.getElementType(); - Value one = rewriter.create( + Value constantOp = rewriter.create( loc, outElementType, (outElementType.isa() - ? rewriter.getFloatAttr(outElementType, 1).cast() - : rewriter.getIntegerAttr(outElementType, 1) + ? rewriter.getFloatAttr(outElementType, elementValue) + .cast() + : rewriter.getIntegerAttr(outElementType, elementValue) .cast())); Value outTensor = rewriter .create( loc, sizeIndex, newResultType.getElementType()) .getResult(); - Value fillOp = - rewriter.create(loc, one, outTensor).getResult(0); + Value fillOp = rewriter.create(loc, constantOp, outTensor) + .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, fillOp); @@ -3452,8 +3475,8 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 40e9ab31b..b61a3e145 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -454,11 +454,10 @@ public: return visitAtenAddCLikeOp(op, operands); } else if (auto scalarOp = dyn_cast(op)) { return visitBinaryScalarOp(scalarOp); - }else if (auto nllForwardOp = dyn_cast(op)) { + } else if (auto nllForwardOp = dyn_cast(op)) { return visitAtenNllLossForwardOp(nllForwardOp, operands); } - // Otherwise, this is an unknown operation. Just mark all results as // having reached a pessimistic fixpoint. return markAllPessimisticFixpoint(op->getResults());