Fix type promotion code for scalar only operations

Fix the type promotion code for scalar only operation to return
TorchType which is the type tracked in ValueKnowledge.scalarType.

- Fix `getPromotedResultScalarType` to return Torch type.
- Add `getBuiltInTypeForTorchScalar` helper to convert scalar type
to builtin type before passing to the next level type promotion
helper `updateResultTypeState`.
- Add `setScalarType` helper to make setting ValueKnowledge.scalarType
  easier.
pull/843/head snapshot-20220507.437
Yi Zhang 2022-05-05 21:35:34 -04:00
parent b20679e1b8
commit 28be6511d2
7 changed files with 122 additions and 28 deletions

View File

@ -5650,6 +5650,30 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
}];
}
def Torch_AtenAddOp : Torch_Op<"aten.add", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::add : (Scalar, Scalar) -> (Scalar)`";
let arguments = (ins
AnyTorchScalarType:$a,
AnyTorchScalarType:$b
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAddOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAddOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -30,11 +30,18 @@ torch_upstream::ScalarType getScalarTypeForType(Type type);
Type getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
Type getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt);
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype);
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype);
bool isBuiltInType(Type type);
// Helper funtion to get rank of `Base tensor type`.
// -1 is returned if the tensorRank can't be determined.
int getTensorRank(Value tensor);

View File

@ -108,8 +108,9 @@ public:
result.returnOp = returnOp;
} else {
return rewriter.notifyMatchFailure(
copyToNonValueTensor,
"unsupported op encountered during abstract analysis");
copyToNonValueTensor, "unsupported op `" +
user->getName().getStringRef() +
"` encountered during abstract analysis");
}
}
return result;

View File

@ -114,6 +114,10 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
/// on the `dtype` from tensors and can't be used on other types like scalar
/// types.
static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
if (!lhs)
return rhs;
if (!rhs)
@ -167,6 +171,14 @@ struct ValueKnowledge {
: dtype(dtype), scalarType(scalarType), kind(kind),
optional(optionalKnowledge) {}
void setScalarType(Type type) {
bool isValidScalarType = type.isa<NumberType, IntType, Torch::FloatType>();
assert(isValidScalarType &&
"scalarType can only be one of NumberType, IntType and FloatType");
scalarType = type;
kind = getTypeKind(type);
}
// Get the static knowledge intrinsic to `type`.
static ValueKnowledge getKnowledgeFromType(Type type) {
ValueKnowledge result = getPessimisticValueState(type.getContext());
@ -420,7 +432,7 @@ private:
// This is the type rule used for deciding dtype for:
// 1. A new tensor created from given data.
// 2. The scalar type for type promotion when a scalar is an operand of a tensor
// and scalar binary operation.
// operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
// If the data is floating-point, the `dtype` is inferred to be the
// default dtype, see `torch.get_default_dtype`.
static Type getDefaultDtypeForTorchScalar(Type type) {
@ -438,12 +450,29 @@ static Type getDefaultDtypeForTorchScalar(Type type) {
"getDefaultDtypeForTorchScalar called on an unsupported type");
}
// This is the type rule used for deciding builtin type for:
// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
// PrimNumToTensorScalarOp.
// 2. The scalar type for type promotion when a scalar is an operand of scalar
// only operation like AtenAddOp.
static Type getBuiltInTypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>())
return Float64Type::get(context);
if (type.isa<Torch::IntType>())
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
return IntegerType::get(context, 1);
llvm_unreachable(
"getBuiltInTypeForTorchScalar called on an unsupported type");
}
static torch_upstream::ResultTypeState
updateResultTypeState(Type scalarType,
const torch_upstream::ResultTypeState &inState) {
assert(isBuiltInType(scalarType) && "scalarType must be builtin type");
torch_upstream::ResultTypeState new_state = inState;
torch_upstream::ScalarType current =
getScalarTypeForType(getDefaultDtypeForTorchScalar(scalarType));
torch_upstream::ScalarType current = getScalarTypeForType(scalarType);
new_state.wrappedResult =
promote_skip_undefined(inState.wrappedResult, current);
return new_state;
@ -481,15 +510,22 @@ updateResultTypeState(ValueKnowledge *tensor, Optional<bool> rankIsNonZero,
return new_state;
}
static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
// Type promotion helper for operators where only scalar operands participating
// in type promotion like AtenAddOp.
//
// \return The return type is a TorchType.
static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
torch_upstream::ResultTypeState state = {};
for (const Type &scalarType : scalarTypes)
state = updateResultTypeState(scalarType, state);
return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
for (const Type &scalarType : scalarTypes) {
state =
updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
}
return getTorchTypeForScalarType(scalarTypes[0].getContext(),
result_type(state));
}
// Returns most generic type Type() if the tensor dtype is unknown.
static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
if (!tensor->dtype)
return Type();
torch_upstream::ResultTypeState state = {};
@ -497,7 +533,8 @@ static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
// wrappedResult which is a lower priority than both dimResult and zeroResult.
state = updateResultTypeState(tensor, /*rankIsNonZero=*/None, state,
/*skipRankCheck=*/true);
state = updateResultTypeState(scalarType, state);
state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
return getTypeForScalarType(scalarType.getContext(), result_type(state));
}
@ -550,8 +587,7 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
Value dtype,
Type inputDType) {
assert(isa<BuiltinDialect>(inputDType.getDialect()) &&
"`inputDType` must be a builtin type");
assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
int64_t dtypeInt;
if (dtype.getType().isa<Torch::NoneType>())
knowledge.dtype = inputDType;
@ -692,7 +728,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Value scalar = op->getOperand(1);
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge);
}
@ -712,8 +748,8 @@ ChangeResult TypeAnalyzer::visitOperation(
Value rhsScalar = op->getOperand(2);
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype =
getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType(
{lhsScalar.getType(), rhsScalar.getType()}));
return incorporateKnowledge(op->getResult(0), knowledge);
}
@ -723,7 +759,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Value scalar = op->getOperand(2);
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge);
}
@ -733,7 +769,7 @@ ChangeResult TypeAnalyzer::visitOperation(
Value scalar = op->getOperand(1);
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(getContext());
knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType());
return incorporateKnowledge(op->getResult(0), knowledge);
}
@ -942,7 +978,7 @@ ChangeResult TypeAnalyzer::visitOperation(
return visitNumToTensorOp(numToTensorOp);
}
if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>(op)) {
if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp, AtenAddOp>(op)) {
return visitBinaryScalarOp(op, operands);
}
@ -1069,10 +1105,9 @@ ChangeResult TypeAnalyzer::visitBinaryScalarOp(
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge =
ValueKnowledge::getScalarPessimisticValueState(op->getContext());
Type resultType = getPromotedResultType(
Type resultType = getPromotedResultScalarType(
{op->getOperand(0).getType(), op->getOperand(1).getType()});
knowledge.scalarType = resultType;
knowledge.kind = getTypeKind(resultType);
knowledge.setScalarType(resultType);
return incorporateKnowledge(op->getResult(0), knowledge);
}
@ -1183,12 +1218,7 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
// `NumToTensor` falls in the latter case.
Type type = op.a().getType();
if (type.isa<Torch::FloatType>())
knowledge.dtype = Float64Type::get(op.getContext());
else if (type.isa<Torch::IntType>())
knowledge.dtype =
IntegerType::get(op.getContext(), 64, IntegerType::Signed);
knowledge.dtype = getBuiltInTypeForTorchScalar(type);
return incorporateKnowledge(op.getResult(), knowledge);
}

View File

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
@ -78,6 +79,19 @@ Type Torch::getTypeForScalarType(
}
}
Type Torch::getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Double:
return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context);
default:
llvm::report_fatal_error(
"Unsupported scalar type to Torch type conversion");
}
}
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype) {
int intType = (int)getScalarTypeForType(dtype);
@ -100,6 +114,10 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
return converted;
}
bool Torch::isBuiltInType(Type type) {
return isa<BuiltinDialect>(type.getDialect());
}
int Torch::getTensorRank(Value tensor) {
int tensorRank = -1;
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();

View File

@ -518,6 +518,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)

View File

@ -127,6 +127,18 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion_scalar_operation(
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float,
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
// CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
// CHECK: return %[[RET]] : !torch.number
func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
%ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
return %ret : !torch.number
}
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,