[TORCH][MLIR] Add and templatize lowering of [`aten.zeros|aten.ones|aten.empty`] ops

- Templatize the lowering of the `aten.zeros` and `aten.ones` ops.
- Add E2E support for the `aten.empty` op.
- Add integer type support to the `aten.mul.Scalar` op lowering.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/480/head
Gaurav Shukla 2021-12-10 20:25:47 +05:30 committed by Gaurav Shukla
parent 528354de84
commit 8d4879feb0
3 changed files with 168 additions and 77 deletions
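Since `torch.empty` returns a tensor with uninitialized contents, the new e2e tests cannot compare its raw output against a golden value; each test instead wraps the result in an expression whose value does not depend on the tensor's contents. A minimal sketch of that trick in plain PyTorch (illustrative only, not part of the commit):

import torch

# 0 * x == 0 for any integer contents, so the result is deterministic.
t_int = torch.empty((3, 4), dtype=torch.int64)
assert (0 * t_int == 0).all()

# |x| > -1.0 holds for any (non-NaN) float contents; this mirrors the
# `torch.abs(...) > -1.0` guard used by the EmptyFloatModule test below.
t_float = torch.empty((3, 4), dtype=torch.float32)
print((torch.abs(t_float) > -1.0).all())

# The commit also lowers integer scalar multiplication (aten.mul.Scalar).
x = torch.randint(10, (3, 4))
assert torch.equal(torch.mul(x, 4), x * 4)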

View File

@@ -637,6 +637,57 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class EmptyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return 0 * torch.empty((3, 4), dtype=torch.int64)
+
+
+@register_test_case(module_factory=lambda: EmptyIntModule())
+def EmptyModule_int(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+class EmptyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
+
+
+@register_test_case(module_factory=lambda: EmptyFloatModule())
+def EmptyModule_float(module, tu: TestUtils):
+    module.forward()
+
+
+class EmptyFalsePinMemoryModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32,
+                                     pin_memory=False)) > -1.0
+
+
+@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
+def EmptyModule_falsePinMemory(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
 class ContiguousModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

View File

@@ -443,8 +443,24 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class ElementwiseMulScalarModule(torch.nn.Module):
+class ElementwiseMulScalarIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 4)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
+def ElementwiseMulScalarModule_int(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseMulScalarFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
 
@@ -456,11 +472,27 @@ class ElementwiseMulScalarModule(torch.nn.Module):
     def forward(self, x):
         return torch.mul(x, 100.0)
 
-@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
-def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
+def ElementwiseMulScalarModule_float(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
+class ElementwiseMulScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 8.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
+def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
     def __init__(self):

View File

@@ -189,6 +189,21 @@ static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
   return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
 }
 
+// Creates a constant of type `elemType` with value `val`.
+static Value getConstant(OpBuilder &b, Location loc, int64_t val,
+                         Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
+
 // Helper function to calculate the output tensor dims for convolution-like ops.
 // Along each dim:
@@ -1828,13 +1843,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(mulScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      mulScalar.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
-    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::MulFOp>(loc, self, other);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::IntegerType>())
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    mulScalar.emitError("unimplemented: Only integer/float dtype supported");
+    return nullptr;
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
@@ -3417,83 +3433,68 @@ public:
 } // namespace
 
 namespace {
-// Converts AtenOnesOp and AtenZerosOp.
-struct ConvertAtenOnesZerosOp : ConversionPattern {
-  ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
-                          context) {}
+// Converts constant tensor allocation like ops.
+template <typename OpTy>
+class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
+public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenOnesOp, AtenZerosOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "not a supported ones or zeros op");
-
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
-    Location loc = op->getLoc();
 
-    Value size, layout, pin_memory;
-    int64_t elementValue;
-    if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      size = onesOp.size();
-      layout = onesOp.layout();
-      pin_memory = onesOp.pin_memory();
-      elementValue = 1;
-    } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      size = zerosOp.size();
-      layout = zerosOp.layout();
-      pin_memory = zerosOp.pin_memory();
-      elementValue = 0;
-    }
-
-    // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!layout.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "only default layout is supported");
-
-    bool pinMemory = false;
-    if (!pin_memory.getType().isa<Torch::NoneType>() &&
-        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
+    // Currently memory pinning and layout features are not supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(
-          op, "pin_memory must be constant bool or None");
+          op, "unimplemented: only default layout is supported");
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: pin_memory must be either None or false");
     }
-    if (pinMemory)
-      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
 
-    SmallVector<Value> sizes, sizeIndex;
-    if (!getListConstructElements(size, sizes)) {
-      return rewriter.notifyMatchFailure(
-          op, "size must be created by ListConstruct");
+    // Memory formats are not supported in the case of `AtenEmptyMemoryFormat`.
+    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
+      if (!op.memory_format().getType().template isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only default memory format is supported");
     }
-    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
-    for (size_t i = 0; i < sizes.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
 
-    RankedTensorType newResultType =
-        getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
-    Type outElementType = newResultType.getElementType();
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: size must be constructed using ListConstruct");
+    }
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
 
-    Value constantOp = rewriter.create<arith::ConstantOp>(
-        loc, outElementType,
-        (outElementType.isa<mlir::FloatType>()
-             ? rewriter.getFloatAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()
-             : rewriter.getIntegerAttr(outElementType, elementValue)
-                   .cast<mlir::Attribute>()));
-    Value outTensor = rewriter
-                          .create<linalg::InitTensorOp>(
-                              loc, sizeIndex, newResultType.getElementType())
-                          .getResult();
-    Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
-                       .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+    auto resultType = typeConverter->convertType(op.getType())
+                          .template cast<RankedTensorType>();
+    Type outElemType = resultType.getElementType();
+
+    // Create an uninitialized tensor of `resultSize` shape. It will be
+    // returned without initialization/filling in the case of
+    // `AtenEmptyMemoryFormatOp`.
+    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, outElemType);
+
+    // `AtenZeros` and `AtenOnes` ops will be filled with corresponding values.
+    if (std::is_same<OpTy, AtenZerosOp>::value) {
+      Value zero = getConstant(rewriter, loc, 0, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
+    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
+      Value one = getConstant(rewriter, loc, 1, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
@@ -3704,8 +3705,15 @@ public:
     patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
     target.addIllegalOp<AtenEmbeddingOp>();
     patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
-    target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
-    patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
+    target.addIllegalOp<AtenEmptyMemoryFormatOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
+        typeConverter, context);
+    target.addIllegalOp<AtenZerosOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
+                                                            context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
+                                                           context);
     target.addIllegalOp<AtenContiguousOp>();
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();