mirror of https://github.com/llvm/torch-mlir
Add convertScalarToDtype helper.
This is to facilitate scalar type conversion in the TorchToLinalg. As part of adding the helper, this PR also: - Updated `AtenAddTensorOp`, `AtenSubTensorOp` to use the helpers to support more type variants. - Added e2e type promotion testing. - Added i32 memref return/arg type to support e2e testing.pull/402/head
parent
e23cabf3a9
commit
05c4dd8e39
|
@ -33,6 +33,8 @@ from . import conv
|
|||
from . import batchnorm
|
||||
from . import quantized_models
|
||||
from . import elementwise
|
||||
from . import type_promotion
|
||||
from . import type_conversion
|
||||
from . import reduction
|
||||
from . import argmax
|
||||
from . import matmul
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TypeConversionF32ToF64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float64)
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeConversionF32ToF64Module())
|
||||
def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class TypeConversionF64ToF32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeConversionF64ToF32Module())
|
||||
def TypeConversionF64ToF32Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5).type(torch.float64))
|
||||
|
||||
class TypeConversionI32ToI64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI32ToI64Module())
|
||||
def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, [2, 3]).type(torch.int32))
|
||||
|
||||
class TypeConversionI64ToI32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int32)
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI64ToI32Module())
|
||||
def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, [2, 3]))
|
|
@ -0,0 +1,113 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
||||
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32),
|
||||
torch.randint(10, [4]))
|
||||
|
||||
|
||||
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, [4]), torch.randn(4))
|
||||
|
||||
|
||||
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2.3)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule())
|
||||
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand().type(torch.float64))
|
||||
|
||||
|
||||
class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, [4]), tu.rand())
|
||||
|
||||
|
||||
class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.add(a, b, alpha=2.3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
|
||||
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
|
@ -1251,29 +1251,65 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Convert a scalar value to the target type. The scalar value can be an element
|
||||
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
||||
// should be converted builtin types.
|
||||
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
||||
Type dtype) {
|
||||
Type scalarType = scalar.getType();
|
||||
if (scalarType == dtype)
|
||||
return scalar;
|
||||
|
||||
static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) {
|
||||
// TODO: For the integer case, we probably need the unconverted dtype to
|
||||
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
|
||||
// be able to know if we need signed or unsigned conversion.
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
if (scalar.getType().isa<mlir::FloatType>()) {
|
||||
// `scalar` will always be f64 since that is what the TypeConverter
|
||||
// converts !torch.float to.
|
||||
auto isByteOrChar = [](Type type) {
|
||||
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
|
||||
return integerTy.getWidth() == 8;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
|
||||
scalarType.isSignlessInteger(1) || dtype.isSignlessInteger(1)) {
|
||||
// TODO: Handle bool type.
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported byte, char or bool type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
|
||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
|
||||
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
||||
return b.create<arith::TruncFOp>(loc, scalar, dtype);
|
||||
} else {
|
||||
assert(scalar.getType().isa<mlir::IntegerType>());
|
||||
// `scalar` will always be i64 since that is what the TypeConverter
|
||||
// converts !torch.int to.
|
||||
// Only scalarFloat width < dtypeFloat width can reach here.
|
||||
return b.create<arith::ExtFOp>(loc, scalar, dtype);
|
||||
}
|
||||
assert(scalarType.isa<mlir::IntegerType>());
|
||||
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
|
||||
// unsigned handling is needed, and we checked for that case above.
|
||||
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
|
||||
}
|
||||
|
||||
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
|
||||
return b.create<arith::FPToSIOp>(loc, scalar, dtype);
|
||||
assert(scalarType.isa<mlir::IntegerType>());
|
||||
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
|
||||
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
||||
return b.create<arith::TruncIOp>(loc, scalar, dtype);
|
||||
// Only scalarInteger width < dtypeInteger width can reach here.
|
||||
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
|
||||
// unsigned handling is needed, and we checked for that case above.
|
||||
return b.create<arith::ExtSIOp>(loc, scalar, dtype);
|
||||
}
|
||||
mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype;
|
||||
return nullptr;
|
||||
|
||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
||||
ArrayRef<Value> operands) {
|
||||
OpBuilder &b, Location loc, TypeConverter *converter,
|
||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||
if (isa<AtenTanhOp>(op))
|
||||
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenExpOp>(op))
|
||||
|
@ -1322,40 +1358,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||
if (add.alpha().getType().isa<Torch::FloatType>()) {
|
||||
add.emitError("unimplemented: !torch.float 'alpha'");
|
||||
return nullptr;
|
||||
Type dtype = converter->convertType(add.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||
} else {
|
||||
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||
return b.create<arith::AddIOp>(loc, lhs, scaled);
|
||||
}
|
||||
if (!add.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
add.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
|
||||
adaptor.alpha());
|
||||
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
|
||||
return b.create<arith::AddFOp>(loc, payloadArgs[0], scaled);
|
||||
}
|
||||
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
||||
AtenSubTensorOp::Adaptor adaptor(operands);
|
||||
if (sub.alpha().getType().isa<Torch::FloatType>()) {
|
||||
sub.emitError("unimplemented: !torch.float 'alpha'");
|
||||
return nullptr;
|
||||
Type dtype = converter->convertType(sub.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
||||
} else {
|
||||
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||
return b.create<arith::SubIOp>(loc, lhs, scaled);
|
||||
}
|
||||
if (!sub.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
sub.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
|
||||
adaptor.alpha());
|
||||
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
|
||||
|
||||
return b.create<arith::SubFOp>(loc, payloadArgs[0], scaled);
|
||||
}
|
||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||
if (!mul.getType()
|
||||
|
@ -1386,7 +1417,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
|
||||
Value expPromoted = promoteScalarToDtype(b, loc, operands[1], dtype);
|
||||
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
||||
}
|
||||
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
|
||||
|
@ -1430,7 +1461,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
||||
auto dtype = clamp.getType().cast<ValueTensorType>().getDtype();
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
clamp.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
|
@ -1445,13 +1478,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
auto result = payloadArgs[0];
|
||||
if (!min.getType().isa<Torch::NoneType>()) {
|
||||
auto minPromoted = promoteScalarToDtype(b, loc, min, dtype);
|
||||
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
result, minPromoted);
|
||||
result = b.create<SelectOp>(loc, pred, minPromoted, result);
|
||||
}
|
||||
if (!max.getType().isa<Torch::NoneType>()) {
|
||||
auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype);
|
||||
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
result, maxPromoted);
|
||||
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
|
||||
|
@ -1459,36 +1492,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
||||
if (!rsub.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
Type dtype = converter->convertType(rsub.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
rsub.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value self = payloadArgs[0];
|
||||
Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
|
||||
Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
|
||||
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
||||
return b.create<arith::SubFOp>(loc, other, mult);
|
||||
}
|
||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||
Value input = payloadArgs[0];
|
||||
Type inType = input.getType();
|
||||
Type outType = atenToDtype.getType().cast<ValueTensorType>().getDtype();
|
||||
Value result;
|
||||
if (!inType.isF32()) {
|
||||
atenToDtype.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
if (inType == outType)
|
||||
result = input;
|
||||
else if (outType.isInteger(64))
|
||||
result = b.create<arith::FPToSIOp>(loc, b.getI64Type(), input);
|
||||
else if (outType.isInteger(1))
|
||||
result = b.create<arith::FPToSIOp>(loc, b.getI1Type(), input);
|
||||
else
|
||||
atenToDtype.emitError("unimplemented: unsupported target dtype");
|
||||
Type dtype = converter->convertType(atenToDtype.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value result = convertScalarToDtype(b, loc, input, dtype);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -1808,7 +1830,7 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value result = createLinalgPayloadCalculationForElementwiseOp(
|
||||
b, loc, payloadArgs, op, operands);
|
||||
b, loc, getTypeConverter(), payloadArgs, op, operands);
|
||||
if (!result) {
|
||||
hadErrorCreatingPayload = true;
|
||||
return;
|
||||
|
@ -2161,7 +2183,7 @@ public:
|
|||
}
|
||||
SmallVector<Value> expectedSize = getTypeConvertedValues(
|
||||
rewriter, loc, typeConverter, expectedSizeTorchInt);
|
||||
if (expectedSize.size() != resultRank) {
|
||||
if (resultRank != (int64_t)expectedSize.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "desired size list length mismatches with the result type rank");
|
||||
}
|
||||
|
|
|
@ -229,8 +229,8 @@ public:
|
|||
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
|
||||
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
|
||||
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
|
||||
AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) {
|
||||
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenLog2Op>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
@ -285,7 +285,7 @@ public:
|
|||
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
|
||||
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
|
||||
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
|
||||
AtenPowTensorScalarOp>(op)) {
|
||||
AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
|
||||
return visitBinaryTensorScalarOp(op, operands);
|
||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
||||
|
|
|
@ -51,6 +51,8 @@ static bool isArgMemRefTypeValid(Type type) {
|
|||
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
|
||||
if (integerTy.isSignlessInteger(64))
|
||||
return true;
|
||||
if (integerTy.isSignlessInteger(32))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
@ -109,7 +111,7 @@ static LogicalResult mungeFunction(
|
|||
auto type = arg.getType();
|
||||
if (!isArgMemRefTypeValid(type))
|
||||
return emitError(arg.getLoc(),
|
||||
"argument must be a memref of f32, f64, i64");
|
||||
"argument must be a memref of f32, f64, i32, i64");
|
||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
|
||||
arg.replaceAllUsesExcept(cast, cast);
|
||||
arg.setType(getAbiTypeForMemRef(type));
|
||||
|
@ -175,6 +177,8 @@ class MungeCallingConventions
|
|||
};
|
||||
|
||||
// Memref return types.
|
||||
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
|
||||
"refbackend_consume_memref_int32_func_return");
|
||||
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
|
||||
"refbackend_consume_memref_int64_func_return");
|
||||
createConsumeFuncReturnFunc(
|
||||
|
|
|
@ -24,7 +24,7 @@ __all__ = [
|
|||
|
||||
|
||||
def checkArgTypeIsSupported(ty):
|
||||
SUPPORTED = [np.float32, np.float64, np.int64]
|
||||
SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
|
||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
|
||||
|
||||
class RefBackendInvoker:
|
||||
|
@ -32,6 +32,10 @@ class RefBackendInvoker:
|
|||
self.ee = ExecutionEngine(module)
|
||||
self.result = None
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||
def consume_memref_i32_return(a):
|
||||
self.result = unranked_memref_to_numpy(a, np.int32)
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
|
||||
def consume_memref_i64_return(a):
|
||||
self.result = unranked_memref_to_numpy(a, np.int64)
|
||||
|
@ -56,6 +60,9 @@ class RefBackendInvoker:
|
|||
def consume_f64_return(a):
|
||||
self.result = a
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
|
||||
consume_memref_i32_return)
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
|
||||
consume_memref_i64_return)
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ from .framework import TestResult, TraceItem
|
|||
class TensorSummary:
|
||||
"""A summary of a tensor's contents."""
|
||||
def __init__(self, tensor):
|
||||
self.min = torch.min(tensor)
|
||||
self.max = torch.max(tensor)
|
||||
self.mean = torch.mean(tensor)
|
||||
self.min = torch.min(tensor.type(torch.float64))
|
||||
self.max = torch.max(tensor.type(torch.float64))
|
||||
self.mean = torch.mean(tensor.type(torch.float64))
|
||||
self.shape = list(tensor.shape)
|
||||
|
||||
def __str__(self):
|
||||
|
@ -148,10 +148,15 @@ class ValueReport:
|
|||
if isinstance(golden, torch.Tensor):
|
||||
if not isinstance(value, torch.Tensor):
|
||||
return self._record_mismatch_type_failure('torch.Tensor', value)
|
||||
|
||||
if value.shape != golden.shape:
|
||||
return self._record_failure(
|
||||
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
|
||||
)
|
||||
if value.dtype != golden.dtype:
|
||||
return self._record_failure(
|
||||
f'shape ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
|
||||
)
|
||||
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
|
||||
return self._record_failure(
|
||||
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
|
||||
|
|
|
@ -981,8 +981,7 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
|
|||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
@ -995,8 +994,7 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
|
|||
// ----
|
||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
@ -1006,15 +1004,14 @@ func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
|
|||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.to.dtype
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
|
||||
// CHECK-LABEL: func @torch.aten.to.dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
|
||||
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
|
||||
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
|
||||
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
|
||||
// CHECK-SAME: -> !torch.tensor<[?,?],si64>
|
||||
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
|
||||
// CHECK-NEXT: return %[[RES]] : !torch.tensor
|
||||
|
||||
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
|
@ -1025,12 +1022,10 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
|||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
||||
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
||||
return %0: !torch.tensor
|
||||
|
|
Loading…
Reference in New Issue