[MLIR][TORCH] Add E2E support for ScalarImplicit and Int.Scalar ops

This commit adds lowering of the `aten.ScalarImplicit` and `aten.Int.Scalar` ops (a short eager-mode sketch of their semantics follows below).

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
Vivek Khandelwal 2022-04-26 17:45:30 +05:30
parent 12b3af70d3
commit c69a1e5688
6 changed files with 144 additions and 0 deletions
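
For context, a minimal eager-mode sketch of the semantics being lowered. The ScalarImplicit call mirrors the new e2e tests at the end of this diff; treating int() on the resulting number as aten::Int.Scalar is an assumption about TorchScript overload resolution.

import torch

# aten::ScalarImplicit extracts the Python number held by a zero-rank tensor.
x = torch.tensor(3.5, dtype=torch.float64)
s = torch.ops.aten.ScalarImplicit(x)   # 3.5, a plain Python float

# aten::Int.Scalar converts a Scalar (int or float) to int; int() applied to
# the number above is assumed to resolve to that overload under scripting.
print(int(s))  # 3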


@@ -6158,6 +6158,30 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
}];
}
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::Int.Scalar : (Scalar) -> (int)`";
let arguments = (ins
AnyTorchScalarType:$a
);
let results = (outs
Torch_IntType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -7139,6 +7163,29 @@ def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
let hasFolder = 1;
}
def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::ScalarImplicit : (Tensor) -> (Scalar)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScalarImplicitOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenScalarImplicitOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -181,6 +181,20 @@ public:
};
} // namespace
namespace {
class ConvertAtenScalarImplicitOp
: public OpConversionPattern<AtenScalarImplicitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.a());
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::
populateTensorScalarInteropPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
@@ -201,4 +215,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context);
target.addIllegalOp<AtenScalarImplicitOp>();
}
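
A small eager-mode sanity check of the behaviour the pattern relies on, using only plain PyTorch: extracting the single element of a zero-rank tensor (what the tensor.extract lowering reads) gives the same value .item() returns.

import torch

# ScalarImplicit on a 0-d tensor yields its sole element, which is exactly
# the value the tensor.extract produced by this pattern reads after conversion.
x = torch.tensor(7, dtype=torch.int64)
assert torch.ops.aten.ScalarImplicit(x) == x.item() == 7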


@@ -1019,6 +1019,23 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
// Constant fold float -> int conversion.
if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<long>(floatAttr.getValue().convertToDouble()));
}
// If the input is int type already, the op is an identity.
if (getType() == getOperand().getType())
return getOperand();
return nullptr;
}
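
A quick eager-mode check of the rounding convention the folder mirrors: static_cast<long> on a double truncates toward zero, as does Python's int() on a float, so constant folding should not change observable results.

# Truncation toward zero for the float -> int constant fold.
assert int(3.9) == 3
assert int(-3.9) == -3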
//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//


@@ -426,6 +426,9 @@ private:
ChangeResult
visitBinaryScalarOp(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitAtenScalarImplicitOp(
AtenScalarImplicitOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
@@ -982,6 +985,9 @@ ChangeResult TypeAnalyzer::visitOperation(
return visitBinaryScalarOp(op, operands);
}
if (auto scalarImplicit = dyn_cast<AtenScalarImplicitOp>(op))
return visitAtenScalarImplicitOp(scalarImplicit, operands);
// Otherwise, this is an unknown operation. Just mark all results as
// having reached a pessimistic fixpoint.
return markAllPessimisticFixpoint(op->getResults());
@@ -1249,6 +1255,19 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
return incorporateKnowledge(op.getResult(), knowledge);
}
ChangeResult TypeAnalyzer::visitAtenScalarImplicitOp(
AtenScalarImplicitOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge =
ValueKnowledge::getScalarPessimisticValueState(op.getContext());
Type dType = operands[0]->getValue().dtype;
if (dType.isa<mlir::FloatType>())
knowledge.setScalarType(Torch::FloatType::get(op->getContext()));
else if (dType.isa<mlir::IntegerType>())
knowledge.setScalarType(Torch::IntType::get(op->getContext()));
return incorporateKnowledge(op->getResult(0), knowledge);
}
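
The refinement above matches eager behaviour: ScalarImplicit on a float-typed tensor yields a Python float, and on an integer-typed tensor a Python int, which is what the !torch.float / !torch.int choice encodes. A quick check in plain PyTorch:

import torch

# Float dtype -> Python float, integer dtype -> Python int.
assert isinstance(torch.ops.aten.ScalarImplicit(torch.tensor(1.0)), float)
assert isinstance(torch.ops.aten.ScalarImplicit(torch.tensor(1)), int)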
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------


@@ -477,6 +477,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
# Primitive ops
emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
@@ -522,6 +523,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)")
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")


@@ -621,6 +621,7 @@ def EmbeddingModuleI64_basic(module, tu: TestUtils):
# ==============================================================================
class EmbeddingModuleI32(torch.nn.Module):
def __init__(self):
@@ -1816,8 +1817,10 @@ class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module):
def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))
# ==============================================================================
class FlipModule(torch.nn.Module):
def __init__(self):
@@ -1857,3 +1860,43 @@ class DetachModule(torch.nn.Module):
module_factory=lambda: DetachModule())
def DetachModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))
# ==============================================================================
class ScalarImplicitFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.float64, True),
    ])
    def forward(self, x):
        return float(torch.ops.aten.ScalarImplicit(x))


@register_test_case(module_factory=lambda: ScalarImplicitFloatModule())
def ScalarImplicitFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand().double())


class ScalarImplicitIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.int64, True),
    ])
    def forward(self, x):
        return int(torch.ops.aten.ScalarImplicit(x))


@register_test_case(module_factory=lambda: ScalarImplicitIntModule())
def ScalarImplicitIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-100, 100, ()))
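
For reference, scripting the same pattern the int test uses shows where the two new ops come from. The module name ScalarChain is illustrative, and the assumption is that TorchScript resolves int() on the number-typed result to aten::Int.Scalar.

import torch

class ScalarChain(torch.nn.Module):
    def forward(self, x):
        return int(torch.ops.aten.ScalarImplicit(x))

# The printed graph should contain the aten::ScalarImplicit -> aten::Int.Scalar
# chain that this commit teaches torch-mlir to lower.
print(torch.jit.script(ScalarChain()).graph)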