mirror of https://github.com/llvm/torch-mlir
Add aten.fill.Scalar op lowering
The lowering of aten.fill.Scalar has been added. The change is made as part of the convert-torch-to-linalg pass.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>

parent 539511c19b
commit 36afa4a4d3
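For reference, a minimal PyTorch sketch (not part of this diff, assuming standard PyTorch semantics) of what aten.fill.Scalar does: the in-place variant fill_ writes one scalar into every element, with the scalar converted to the tensor's dtype.

import torch

# aten.fill.Scalar semantics, shown via the in-place variant fill_:
# every element takes the scalar value, cast to the tensor's dtype.
t = torch.randn(3, 2, 4, dtype=torch.float64)
torch.ops.aten.fill_(t, 3)          # int scalar into a float64 tensor

assert t.dtype == torch.float64     # dtype comes from the tensor
assert bool((t == 3.0).all())       # every element now holds 3.0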
@@ -758,3 +758,54 @@ class DropoutModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: DropoutModule())
 def DropoutModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+
+class Fill_TensorFloat64WithFloat32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, tensor):
+        return torch.ops.aten.fill_(tensor, 3.0)
+
+@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
+def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4))
+
+
+class Fill_TensorFloat64WithFloat64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, tensor):
+        return torch.ops.aten.fill_(tensor, 3.0)
+
+@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
+def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4).to(torch.float64))
+
+
+class Fill_TensorFloat64WithInt64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, tensor):
+        return torch.ops.aten.fill_(tensor, 3)
+
+@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
+def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4).to(torch.float64))
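The three tests differ only in the tensor's element type and the Python scalar passed to fill_, exercising the scalar-to-dtype conversion the lowering must perform. The same dtype matrix in plain PyTorch (illustration only, not part of the diff):

import torch

# The fill value follows the tensor's element type, not the other way round.
f32 = torch.randn(3, 2, 4).fill_(3.0)                        # float into float32
f64 = torch.randn(3, 2, 4, dtype=torch.float64).fill_(3.0)   # float into float64
i_into_f64 = torch.randn(3, 2, 4, dtype=torch.float64).fill_(3)  # int into float64

assert f32.dtype == torch.float32
assert f64.dtype == torch.float64
assert i_into_f64.dtype == torch.float64     # int 3 is cast to 3.0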
@@ -182,6 +182,14 @@ static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
 }
 
+// Creates a tensor with required `sizes` and `elemTy` and fills it with
+// `initElem`.
+static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                              Type elemTy, Value initElem) {
+  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
+  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
+}
+
 // Helper function to calculate the output tensor dims for convolution-like ops.
 // Along each dim:
 // dim_out =
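createInitTensor generalizes the existing createZeroInitTensor: allocate a buffer of the requested sizes and element type with linalg.init_tensor, then populate it with linalg.fill. Conceptually it is the building block behind a one-call constructor like torch.full, as this hedged sketch shows:

import torch

# Analog of createInitTensor(sizes, elemTy, initElem): a fresh tensor of
# the requested shape and element type, filled with a single value.
sizes, init_elem = (3, 2, 4), 3.0
filled = torch.full(sizes, init_elem, dtype=torch.float64)

assert filled.shape == torch.Size(sizes)
assert bool((filled == init_elem).all())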
@@ -2782,6 +2790,33 @@ public:
 };
 } // namespace
 
+
+namespace {
+class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenFill_ScalarOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value self = adaptor.self();
+    Value initVal = adaptor.value();
+    auto tensorType = self.getType().cast<RankedTensorType>();
+
+    Value initValCasted = convertScalarToDtype(rewriter, loc, initVal,
+                                               tensorType.getElementType());
+    Value result =
+        createInitTensor(rewriter, loc, getTensorSizes(rewriter, loc, self),
+                         tensorType.getElementType(), initValCasted);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
+
 namespace {
 class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
 public:
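The pattern never mutates the input: it casts the scalar to the tensor's element type, builds an init tensor of the input's runtime sizes, fills it, and replaces the op with that fresh value. A Python sketch of the same rewrite (lower_fill_scalar is a hypothetical helper for illustration, not torch-mlir API):

import torch

def lower_fill_scalar(self: torch.Tensor, value) -> torch.Tensor:
    # convertScalarToDtype analog: cast the scalar to the element type.
    init_val_casted = torch.tensor(value, dtype=self.dtype)
    # createInitTensor analog: fresh buffer of self's sizes, filled once.
    return torch.full(tuple(self.shape), init_val_casted.item(),
                      dtype=self.dtype)

t = torch.randn(3, 2, 4, dtype=torch.float64)
result = lower_fill_scalar(t, 3)
assert result.dtype == torch.float64 and bool((result == 3.0).all())
assert result is not t      # value semantics: the input is left untouched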
@@ -3064,6 +3099,8 @@ public:
     patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
     target.addIllegalOp<AtenDropoutOp>();
     patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
+    target.addIllegalOp<AtenFill_ScalarOp>();
+    patterns.add<ConvertAtenFill_ScalarOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -92,7 +92,8 @@ public:
     } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                    AtenTransposeIntOp, TensorStaticInfoCastOp,
                    AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
-                   AtenPermuteOp, AtenViewOp, AtenExpandOp>(op)) {
+                   AtenPermuteOp, AtenViewOp, AtenExpandOp,
+                   AtenFill_ScalarOp>(op)) {
       // AtenContiguousOp might return a view, so this is conservatively
       // correct. We could potentially be more precise and identify the cases
       // that it does not return a view and treat those as having value
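Per the quoted comment, this list holds ops whose result may alias their operand, and fill_'s result is exactly its input tensor, so it belongs here. A quick PyTorch demonstration of that aliasing (illustration only):

import torch

# fill_ returns its operand, so the "result" and the input share storage;
# any rewrite toward value semantics must account for that.
a = torch.zeros(2, 3)
b = torch.ops.aten.fill_(a, 1.0)
b.add_(1.0)                      # mutating the result ...
assert bool((a == 2.0).all())    # ... is visible through the original name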