mirror of https://github.com/llvm/torch-mlir
Add convertScalarToDtype helper.
This is to facilitate scalar type conversion in the TorchToLinalg. As part of adding the helper, this PR also: - Updated `AtenAddTensorOp`, `AtenSubTensorOp` to use the helpers to support more type variants. - Added e2e type promotion testing. - Added i32 memref return/arg type to support e2e testing.pull/402/head
parent
e23cabf3a9
commit
05c4dd8e39
|
@ -33,6 +33,8 @@ from . import conv
|
||||||
from . import batchnorm
|
from . import batchnorm
|
||||||
from . import quantized_models
|
from . import quantized_models
|
||||||
from . import elementwise
|
from . import elementwise
|
||||||
|
from . import type_promotion
|
||||||
|
from . import type_conversion
|
||||||
from . import reduction
|
from . import reduction
|
||||||
from . import argmax
|
from . import argmax
|
||||||
from . import matmul
|
from . import matmul
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TypeConversionF32ToF64Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True)
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.to(torch.float64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeConversionF32ToF64Module())
|
||||||
|
def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class TypeConversionF64ToF32Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float64, True)
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.to(torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeConversionF64ToF32Module())
|
||||||
|
def TypeConversionF64ToF32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 5).type(torch.float64))
|
||||||
|
|
||||||
|
class TypeConversionI32ToI64Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True)
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.to(torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeConversionI32ToI64Module())
|
||||||
|
def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(5, [2, 3]).type(torch.int32))
|
||||||
|
|
||||||
|
class TypeConversionI64ToI32Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True)
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.to(torch.int32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeConversionI64ToI32Module())
|
||||||
|
def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(5, [2, 3]))
|
|
@ -0,0 +1,113 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.add(a, b, alpha=3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
||||||
|
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.randint(10, [4]).type(torch.int32),
|
||||||
|
torch.randint(10, [4]))
|
||||||
|
|
||||||
|
|
||||||
|
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.add(a, b, alpha=3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||||
|
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, [4]), torch.randn(4))
|
||||||
|
|
||||||
|
|
||||||
|
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.add(a, b, alpha=2.3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule())
|
||||||
|
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4), tu.rand().type(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
([], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.add(a, b, alpha=2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||||
|
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, [4]), tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.add(a, b, alpha=2.3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
|
||||||
|
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4), tu.rand())
|
|
@ -1251,29 +1251,65 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Convert a scalar value to the target type. The scalar value can be an element
|
||||||
|
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
||||||
|
// should be converted builtin types.
|
||||||
|
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
||||||
|
Type dtype) {
|
||||||
|
Type scalarType = scalar.getType();
|
||||||
|
if (scalarType == dtype)
|
||||||
|
return scalar;
|
||||||
|
|
||||||
static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) {
|
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
|
||||||
// TODO: For the integer case, we probably need the unconverted dtype to
|
|
||||||
// be able to know if we need signed or unsigned conversion.
|
// be able to know if we need signed or unsigned conversion.
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
auto isByteOrChar = [](Type type) {
|
||||||
if (scalar.getType().isa<mlir::FloatType>()) {
|
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
|
||||||
// `scalar` will always be f64 since that is what the TypeConverter
|
return integerTy.getWidth() == 8;
|
||||||
// converts !torch.float to.
|
|
||||||
return b.create<arith::TruncFOp>(loc, scalar, dtype);
|
|
||||||
} else {
|
|
||||||
assert(scalar.getType().isa<mlir::IntegerType>());
|
|
||||||
// `scalar` will always be i64 since that is what the TypeConverter
|
|
||||||
// converts !torch.int to.
|
|
||||||
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
|
|
||||||
}
|
}
|
||||||
}
|
return false;
|
||||||
mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype;
|
};
|
||||||
|
|
||||||
|
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
|
||||||
|
scalarType.isSignlessInteger(1) || dtype.isSignlessInteger(1)) {
|
||||||
|
// TODO: Handle bool type.
|
||||||
|
mlir::emitError(loc)
|
||||||
|
<< "unsupported byte, char or bool type for convertScalarToDtype "
|
||||||
|
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
|
||||||
|
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
|
||||||
|
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
||||||
|
return b.create<arith::TruncFOp>(loc, scalar, dtype);
|
||||||
|
// Only scalarFloat width < dtypeFloat width can reach here.
|
||||||
|
return b.create<arith::ExtFOp>(loc, scalar, dtype);
|
||||||
|
}
|
||||||
|
assert(scalarType.isa<mlir::IntegerType>());
|
||||||
|
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
|
||||||
|
// unsigned handling is needed, and we checked for that case above.
|
||||||
|
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||||
|
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
|
||||||
|
return b.create<arith::FPToSIOp>(loc, scalar, dtype);
|
||||||
|
assert(scalarType.isa<mlir::IntegerType>());
|
||||||
|
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
|
||||||
|
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
||||||
|
return b.create<arith::TruncIOp>(loc, scalar, dtype);
|
||||||
|
// Only scalarInteger width < dtypeInteger width can reach here.
|
||||||
|
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
|
||||||
|
// unsigned handling is needed, and we checked for that case above.
|
||||||
|
return b.create<arith::ExtSIOp>(loc, scalar, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||||
|
}
|
||||||
|
|
||||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
OpBuilder &b, Location loc, TypeConverter *converter,
|
||||||
ArrayRef<Value> operands) {
|
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||||
if (isa<AtenTanhOp>(op))
|
if (isa<AtenTanhOp>(op))
|
||||||
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenExpOp>(op))
|
if (isa<AtenExpOp>(op))
|
||||||
|
@ -1322,40 +1358,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||||
if (add.alpha().getType().isa<Torch::FloatType>()) {
|
Type dtype = converter->convertType(add.getType())
|
||||||
add.emitError("unimplemented: !torch.float 'alpha'");
|
.cast<RankedTensorType>()
|
||||||
return nullptr;
|
.getElementType();
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
|
||||||
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
|
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||||
|
} else {
|
||||||
|
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||||
|
return b.create<arith::AddIOp>(loc, lhs, scaled);
|
||||||
}
|
}
|
||||||
if (!add.getType()
|
|
||||||
.cast<ValueTensorType>()
|
|
||||||
.getDtype()
|
|
||||||
.isa<mlir::FloatType>()) {
|
|
||||||
add.emitError("unimplemented: non-floating point dtype");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
|
|
||||||
adaptor.alpha());
|
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
|
|
||||||
return b.create<arith::AddFOp>(loc, payloadArgs[0], scaled);
|
|
||||||
}
|
}
|
||||||
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
||||||
AtenSubTensorOp::Adaptor adaptor(operands);
|
AtenSubTensorOp::Adaptor adaptor(operands);
|
||||||
if (sub.alpha().getType().isa<Torch::FloatType>()) {
|
Type dtype = converter->convertType(sub.getType())
|
||||||
sub.emitError("unimplemented: !torch.float 'alpha'");
|
.cast<RankedTensorType>()
|
||||||
return nullptr;
|
.getElementType();
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
|
||||||
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
|
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
||||||
|
} else {
|
||||||
|
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||||
|
return b.create<arith::SubIOp>(loc, lhs, scaled);
|
||||||
}
|
}
|
||||||
if (!sub.getType()
|
|
||||||
.cast<ValueTensorType>()
|
|
||||||
.getDtype()
|
|
||||||
.isa<mlir::FloatType>()) {
|
|
||||||
sub.emitError("unimplemented: non-floating point dtype");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
|
|
||||||
adaptor.alpha());
|
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
|
|
||||||
|
|
||||||
return b.create<arith::SubFOp>(loc, payloadArgs[0], scaled);
|
|
||||||
}
|
}
|
||||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||||
if (!mul.getType()
|
if (!mul.getType()
|
||||||
|
@ -1386,7 +1417,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
|
Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
|
||||||
Value expPromoted = promoteScalarToDtype(b, loc, operands[1], dtype);
|
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
||||||
}
|
}
|
||||||
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
|
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
|
||||||
|
@ -1430,7 +1461,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||||
}
|
}
|
||||||
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
||||||
auto dtype = clamp.getType().cast<ValueTensorType>().getDtype();
|
Type dtype = converter->convertType(clamp.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!dtype.isa<mlir::FloatType>()) {
|
||||||
clamp.emitError("unimplemented: non-floating point dtype");
|
clamp.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1445,13 +1478,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
auto result = payloadArgs[0];
|
auto result = payloadArgs[0];
|
||||||
if (!min.getType().isa<Torch::NoneType>()) {
|
if (!min.getType().isa<Torch::NoneType>()) {
|
||||||
auto minPromoted = promoteScalarToDtype(b, loc, min, dtype);
|
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
|
||||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
result, minPromoted);
|
result, minPromoted);
|
||||||
result = b.create<SelectOp>(loc, pred, minPromoted, result);
|
result = b.create<SelectOp>(loc, pred, minPromoted, result);
|
||||||
}
|
}
|
||||||
if (!max.getType().isa<Torch::NoneType>()) {
|
if (!max.getType().isa<Torch::NoneType>()) {
|
||||||
auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype);
|
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||||
result, maxPromoted);
|
result, maxPromoted);
|
||||||
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
|
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
|
||||||
|
@ -1459,36 +1492,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
||||||
if (!rsub.getType()
|
Type dtype = converter->convertType(rsub.getType())
|
||||||
.cast<ValueTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getDtype()
|
.getElementType();
|
||||||
.isa<mlir::FloatType>()) {
|
if (!dtype.isa<mlir::FloatType>()) {
|
||||||
rsub.emitError("unimplemented: non-floating point dtype");
|
rsub.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
Value self = payloadArgs[0];
|
Value self = payloadArgs[0];
|
||||||
Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
|
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
|
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
|
||||||
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
||||||
return b.create<arith::SubFOp>(loc, other, mult);
|
return b.create<arith::SubFOp>(loc, other, mult);
|
||||||
}
|
}
|
||||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||||
Value input = payloadArgs[0];
|
Value input = payloadArgs[0];
|
||||||
Type inType = input.getType();
|
Type dtype = converter->convertType(atenToDtype.getType())
|
||||||
Type outType = atenToDtype.getType().cast<ValueTensorType>().getDtype();
|
.cast<RankedTensorType>()
|
||||||
Value result;
|
.getElementType();
|
||||||
if (!inType.isF32()) {
|
Value result = convertScalarToDtype(b, loc, input, dtype);
|
||||||
atenToDtype.emitError("unimplemented: non-floating point dtype");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (inType == outType)
|
|
||||||
result = input;
|
|
||||||
else if (outType.isInteger(64))
|
|
||||||
result = b.create<arith::FPToSIOp>(loc, b.getI64Type(), input);
|
|
||||||
else if (outType.isInteger(1))
|
|
||||||
result = b.create<arith::FPToSIOp>(loc, b.getI1Type(), input);
|
|
||||||
else
|
|
||||||
atenToDtype.emitError("unimplemented: unsupported target dtype");
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1808,7 +1830,7 @@ struct ConvertElementwiseOp : ConversionPattern {
|
||||||
/*iteratorTypes=*/iteratorTypes,
|
/*iteratorTypes=*/iteratorTypes,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||||
Value result = createLinalgPayloadCalculationForElementwiseOp(
|
Value result = createLinalgPayloadCalculationForElementwiseOp(
|
||||||
b, loc, payloadArgs, op, operands);
|
b, loc, getTypeConverter(), payloadArgs, op, operands);
|
||||||
if (!result) {
|
if (!result) {
|
||||||
hadErrorCreatingPayload = true;
|
hadErrorCreatingPayload = true;
|
||||||
return;
|
return;
|
||||||
|
@ -2161,7 +2183,7 @@ public:
|
||||||
}
|
}
|
||||||
SmallVector<Value> expectedSize = getTypeConvertedValues(
|
SmallVector<Value> expectedSize = getTypeConvertedValues(
|
||||||
rewriter, loc, typeConverter, expectedSizeTorchInt);
|
rewriter, loc, typeConverter, expectedSizeTorchInt);
|
||||||
if (expectedSize.size() != resultRank) {
|
if (resultRank != (int64_t)expectedSize.size()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "desired size list length mismatches with the result type rank");
|
op, "desired size list length mismatches with the result type rank");
|
||||||
}
|
}
|
||||||
|
|
|
@ -229,8 +229,8 @@ public:
|
||||||
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
|
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
|
||||||
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
|
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
|
||||||
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
|
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
|
||||||
AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
|
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
|
||||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) {
|
AtenFloorOp, AtenLog2Op>(op)) {
|
||||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,7 +285,7 @@ public:
|
||||||
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
|
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
|
||||||
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
|
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
|
||||||
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
|
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
|
||||||
AtenPowTensorScalarOp>(op)) {
|
AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
|
||||||
return visitBinaryTensorScalarOp(op, operands);
|
return visitBinaryTensorScalarOp(op, operands);
|
||||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||||
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
||||||
|
|
|
@ -51,6 +51,8 @@ static bool isArgMemRefTypeValid(Type type) {
|
||||||
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
|
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
|
||||||
if (integerTy.isSignlessInteger(64))
|
if (integerTy.isSignlessInteger(64))
|
||||||
return true;
|
return true;
|
||||||
|
if (integerTy.isSignlessInteger(32))
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -109,7 +111,7 @@ static LogicalResult mungeFunction(
|
||||||
auto type = arg.getType();
|
auto type = arg.getType();
|
||||||
if (!isArgMemRefTypeValid(type))
|
if (!isArgMemRefTypeValid(type))
|
||||||
return emitError(arg.getLoc(),
|
return emitError(arg.getLoc(),
|
||||||
"argument must be a memref of f32, f64, i64");
|
"argument must be a memref of f32, f64, i32, i64");
|
||||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
|
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
|
||||||
arg.replaceAllUsesExcept(cast, cast);
|
arg.replaceAllUsesExcept(cast, cast);
|
||||||
arg.setType(getAbiTypeForMemRef(type));
|
arg.setType(getAbiTypeForMemRef(type));
|
||||||
|
@ -175,6 +177,8 @@ class MungeCallingConventions
|
||||||
};
|
};
|
||||||
|
|
||||||
// Memref return types.
|
// Memref return types.
|
||||||
|
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
|
||||||
|
"refbackend_consume_memref_int32_func_return");
|
||||||
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
|
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
|
||||||
"refbackend_consume_memref_int64_func_return");
|
"refbackend_consume_memref_int64_func_return");
|
||||||
createConsumeFuncReturnFunc(
|
createConsumeFuncReturnFunc(
|
||||||
|
|
|
@ -24,7 +24,7 @@ __all__ = [
|
||||||
|
|
||||||
|
|
||||||
def checkArgTypeIsSupported(ty):
|
def checkArgTypeIsSupported(ty):
|
||||||
SUPPORTED = [np.float32, np.float64, np.int64]
|
SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
|
||||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
|
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
|
||||||
|
|
||||||
class RefBackendInvoker:
|
class RefBackendInvoker:
|
||||||
|
@ -32,6 +32,10 @@ class RefBackendInvoker:
|
||||||
self.ee = ExecutionEngine(module)
|
self.ee = ExecutionEngine(module)
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
|
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||||
|
def consume_memref_i32_return(a):
|
||||||
|
self.result = unranked_memref_to_numpy(a, np.int32)
|
||||||
|
|
||||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||||
def consume_memref_i64_return(a):
|
def consume_memref_i64_return(a):
|
||||||
self.result = unranked_memref_to_numpy(a, np.int64)
|
self.result = unranked_memref_to_numpy(a, np.int64)
|
||||||
|
@ -56,6 +60,9 @@ class RefBackendInvoker:
|
||||||
def consume_f64_return(a):
|
def consume_f64_return(a):
|
||||||
self.result = a
|
self.result = a
|
||||||
|
|
||||||
|
self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
|
||||||
|
consume_memref_i32_return)
|
||||||
|
|
||||||
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
|
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
|
||||||
consume_memref_i64_return)
|
consume_memref_i64_return)
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@ from .framework import TestResult, TraceItem
|
||||||
class TensorSummary:
|
class TensorSummary:
|
||||||
"""A summary of a tensor's contents."""
|
"""A summary of a tensor's contents."""
|
||||||
def __init__(self, tensor):
|
def __init__(self, tensor):
|
||||||
self.min = torch.min(tensor)
|
self.min = torch.min(tensor.type(torch.float64))
|
||||||
self.max = torch.max(tensor)
|
self.max = torch.max(tensor.type(torch.float64))
|
||||||
self.mean = torch.mean(tensor)
|
self.mean = torch.mean(tensor.type(torch.float64))
|
||||||
self.shape = list(tensor.shape)
|
self.shape = list(tensor.shape)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -148,10 +148,15 @@ class ValueReport:
|
||||||
if isinstance(golden, torch.Tensor):
|
if isinstance(golden, torch.Tensor):
|
||||||
if not isinstance(value, torch.Tensor):
|
if not isinstance(value, torch.Tensor):
|
||||||
return self._record_mismatch_type_failure('torch.Tensor', value)
|
return self._record_mismatch_type_failure('torch.Tensor', value)
|
||||||
|
|
||||||
if value.shape != golden.shape:
|
if value.shape != golden.shape:
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
|
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
|
||||||
)
|
)
|
||||||
|
if value.dtype != golden.dtype:
|
||||||
|
return self._record_failure(
|
||||||
|
f'shape ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
|
||||||
|
)
|
||||||
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
|
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
|
||||||
return self._record_failure(
|
return self._record_failure(
|
||||||
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
|
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
|
||||||
|
|
|
@ -981,8 +981,7 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
|
||||||
// ----
|
// ----
|
||||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
|
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||||
// CHECK-SAME: -> !torch.tensor {
|
|
||||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
|
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
|
||||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
|
||||||
// CHECK: return %[[CAST]] : !torch.tensor
|
// CHECK: return %[[CAST]] : !torch.tensor
|
||||||
|
@ -995,8 +994,7 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
// ----
|
// ----
|
||||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
|
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
||||||
// CHECK-SAME: -> !torch.tensor {
|
|
||||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
|
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
|
||||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
|
||||||
// CHECK: return %[[CAST]] : !torch.tensor
|
// CHECK: return %[[CAST]] : !torch.tensor
|
||||||
|
@ -1006,15 +1004,14 @@ func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.to.dtype
|
// CHECK-LABEL: func @torch.aten.to.dtype(
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
|
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
|
||||||
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
|
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
|
||||||
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
|
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
|
||||||
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
|
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
|
||||||
// CHECK-SAME: -> !torch.tensor<[?,?],si64>
|
// CHECK-SAME: -> !torch.tensor<[?,?],si64>
|
||||||
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
|
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
|
||||||
// CHECK-NEXT: return %[[RES]] : !torch.tensor
|
// CHECK-NEXT: return %[[RES]] : !torch.tensor
|
||||||
|
|
||||||
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
|
@ -1025,12 +1022,10 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
||||||
|
|
||||||
// ----
|
// ----
|
||||||
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
||||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int)
|
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
||||||
// CHECK-SAME: -> !torch.tensor {
|
|
||||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
||||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
|
||||||
// CHECK: return %[[CAST]] : !torch.tensor
|
// CHECK: return %[[CAST]] : !torch.tensor
|
||||||
|
|
||||||
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||||
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
||||||
return %0: !torch.tensor
|
return %0: !torch.tensor
|
||||||
|
|
Loading…
Reference in New Issue