mirror of https://github.com/llvm/torch-mlir
Add contiguous
parent eb6996d557
commit 2ea2ab518b
@@ -446,6 +446,7 @@ class BroadcastToModule(torch.nn.Module):
 def BroadcastToModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 1))
 
+
 class OnesModuleInt(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -475,3 +476,20 @@ class OnesModuleFloat(torch.nn.Module):
 @register_test_case(module_factory=lambda: OnesModuleFloat())
 def OnesModuleFloat_basic(module, tu: TestUtils):
     module.forward()
+
+class ContiguousModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.contiguous()
+
+
+@register_test_case(module_factory=lambda: ContiguousModule())
+def ContiguousModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1))
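For reference (this note is not part of the diff): in eager PyTorch, `Tensor.contiguous()` returns a tensor holding the same values. When the input already has a contiguous layout it does not copy; otherwise it materializes a fresh, contiguous copy. A minimal standalone sketch of that behavior, using only stock PyTorch APIs:

```python
import torch

# Already-contiguous input: contiguous() does not copy; the result shares
# the input's storage.
x = torch.rand(3, 1)
assert x.is_contiguous()
assert x.contiguous().data_ptr() == x.data_ptr()

# Non-contiguous input (a transposed view): contiguous() makes a copy, but
# the element values are unchanged, which is what the e2e harness compares
# against eager PyTorch.
y = torch.rand(3, 4).t()
assert not y.is_contiguous()
z = y.contiguous()
assert z.is_contiguous() and torch.equal(y, z)
```

The new test only checks that the compiled `forward` matches eager PyTorch numerically on a (3, 1) input; memory layout is not observable at that level.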
@@ -2555,6 +2555,23 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenContiguousOp op, llvm::ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    AtenContiguousOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.self());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
 public:
@@ -2685,6 +2702,8 @@ public:
     patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
     target.addIllegalOp<AtenOnesOp>();
     patterns.add<ConvertAtenOnesOp>(typeConverter, context);
+    target.addIllegalOp<AtenContiguousOp>();
+    patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
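A note on the lowering (my reading of the pattern above, not text from the commit): once the operand types pass `verifyLinalgCompatibleTypes`, the converted operand is a value-semantic tensor, and the linalg-on-tensors level does not model strides or memory layout. `aten.contiguous` can therefore be lowered by forwarding its input unchanged, which is what `rewriter.replaceOp(op, adaptor.self())` does. Expressed as a Python sketch (the function name here is illustrative, not from the repo):

```python
import torch

def lowered_contiguous(x: torch.Tensor) -> torch.Tensor:
    # Value-level equivalent of the ConvertAtenContiguousOp pattern:
    # the result is just the input value, with no copy and no layout change.
    return x

x = torch.rand(3, 1)
# The identity lowering agrees with eager contiguous() element-for-element.
assert torch.equal(lowered_contiguous(x), x.contiguous())
```

The two registration lines in the second hunk then mark `AtenContiguousOp` illegal on the conversion target and hand the pattern to the driver, so `applyPartialConversion` rewrites every occurrence of the op.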
@@ -90,8 +90,12 @@ public:
     if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
       copyToValueTensorOps.push_back(copyToValueTensor);
     } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
-                   AtenTransposeIntOp, AtenPermuteOp, TensorStaticInfoCastOp,
-                   AtenBroadcastToOp>(op)) {
+                   AtenTransposeIntOp, TensorStaticInfoCastOp,
+                   AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp>(op)) {
+      // AtenContiguousOp might return a view, so this is conservatively
+      // correct. We could potentially be more precise and identify the cases
+      // that it does not return a view and treat those as having value
+      // semantics.
       viewLikeOps.push_back(op);
       llvm::append_range(workList, op->getResult(0).getUsers());
     } else {