mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.zeros op
This commit adds lowering of `aten.zeros` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/465/head
parent
977b1b03ea
commit
9958cf08b6
|
@ -900,7 +900,6 @@ class NumelModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.numel(input)
|
||||
|
||||
|
@ -918,10 +917,89 @@ class NumelZeroRankModule(torch.nn.Module):
|
|||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.numel(input)
|
||||
|
||||
@register_test_case(module_factory=lambda: NumelZeroRankModule())
|
||||
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10,[]))
|
||||
|
||||
|
||||
class ZerosModuleInt2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4, dtype=torch.int64)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
|
||||
def ZerosModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ZerosModuleInt3D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4, 5, dtype=torch.int64)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
|
||||
def ZerosModuleInt3D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ZerosModuleFloat2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4, dtype=torch.float32)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
|
||||
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ZerosModuleFloat3D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4, 5, dtype=torch.float32)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
|
||||
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ZerosModuleFalsePinMemory(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
|
||||
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -3269,24 +3269,44 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
// Converts AtenOnesOp and AtenZerosOp.
|
||||
struct ConvertAtenOnesZerosOp : ConversionPattern {
|
||||
ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOnesOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenOnesOp, AtenZerosOp>(op))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"not a supported ones or zeros op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
SmallVector<Value, 3> opArguments;
|
||||
int64_t elementValue;
|
||||
|
||||
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
|
||||
opArguments.insert(opArguments.end(),
|
||||
{onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
|
||||
elementValue = 1;
|
||||
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
|
||||
opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
|
||||
zerosOp.pin_memory()});
|
||||
elementValue = 0;
|
||||
}
|
||||
|
||||
// We ignore device, but add simple asserts for unimplemented kwargs
|
||||
if (!op.layout().getType().isa<Torch::NoneType>())
|
||||
if (!opArguments[1].getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only default layout is supported");
|
||||
|
||||
bool pinMemory = false;
|
||||
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
|
||||
!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory))) {
|
||||
if (!opArguments[2].getType().isa<Torch::NoneType>() &&
|
||||
!matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "pin_memory must be constant bool or None");
|
||||
}
|
||||
|
@ -3294,7 +3314,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
|
||||
|
||||
SmallVector<Value> size, sizeIndex;
|
||||
if (!getListConstructElements(op.size(), size)) {
|
||||
if (!getListConstructElements(opArguments[0], size)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "size must be created by ListConstruct");
|
||||
}
|
||||
|
@ -3303,21 +3323,24 @@ public:
|
|||
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
|
||||
|
||||
RankedTensorType newResultType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Type outElementType = newResultType.getElementType();
|
||||
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
Value constantOp = rewriter.create<arith::ConstantOp>(
|
||||
loc, outElementType,
|
||||
(outElementType.isa<mlir::FloatType>()
|
||||
? rewriter.getFloatAttr(outElementType, 1).cast<mlir::Attribute>()
|
||||
: rewriter.getIntegerAttr(outElementType, 1)
|
||||
? rewriter.getFloatAttr(outElementType, elementValue)
|
||||
.cast<mlir::Attribute>()
|
||||
: rewriter.getIntegerAttr(outElementType, elementValue)
|
||||
.cast<mlir::Attribute>()));
|
||||
Value outTensor = rewriter
|
||||
.create<linalg::InitTensorOp>(
|
||||
loc, sizeIndex, newResultType.getElementType())
|
||||
.getResult();
|
||||
Value fillOp =
|
||||
rewriter.create<linalg::FillOp>(loc, one, outTensor).getResult(0);
|
||||
Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
|
||||
|
||||
|
@ -3452,8 +3475,8 @@ public:
|
|||
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenEmbeddingOp>();
|
||||
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenOnesOp>();
|
||||
patterns.add<ConvertAtenOnesOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
|
||||
patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp>();
|
||||
|
|
|
@ -454,11 +454,10 @@ public:
|
|||
return visitAtenAddCLikeOp(op, operands);
|
||||
} else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
|
||||
return visitBinaryScalarOp(scalarOp);
|
||||
}else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
|
||||
} else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
|
||||
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
||||
}
|
||||
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
// having reached a pessimistic fixpoint.
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
|
|
Loading…
Reference in New Issue