mirror of https://github.com/llvm/torch-mlir
[Torch] fix toBuiltinTensor() (#3415)
* Let `toBuiltinTensor()` reflects the original dtype of `!torch.vtensor`. * Backend handles dtype conversion themselves.pull/3437/head
parent
75af64fc12
commit
689efc8917
|
@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
std::numeric_limits<float>::lowest()))
|
std::numeric_limits<float>::lowest()))
|
||||||
return failure();
|
return failure();
|
||||||
auto minSplatAttr = SplatElementsAttr::get(
|
auto minSplatAttr = SplatElementsAttr::get(
|
||||||
resultType.toBuiltinTensor().clone(resultDtype),
|
resultType.toBuiltinTensor(),
|
||||||
rewriter.getFloatAttr(resultDtype, minValue));
|
rewriter.getFloatAttr(resultDtype, minValue));
|
||||||
min = rewriter.create<Torch::ValueTensorLiteralOp>(
|
min = rewriter.create<Torch::ValueTensorLiteralOp>(
|
||||||
binder.getLoc(), resultType, minSplatAttr);
|
binder.getLoc(), resultType, minSplatAttr);
|
||||||
|
@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
std::numeric_limits<float>::max()))
|
std::numeric_limits<float>::max()))
|
||||||
return failure();
|
return failure();
|
||||||
auto maxSplatAttr = SplatElementsAttr::get(
|
auto maxSplatAttr = SplatElementsAttr::get(
|
||||||
resultType.toBuiltinTensor().clone(resultDtype),
|
resultType.toBuiltinTensor(),
|
||||||
rewriter.getFloatAttr(resultDtype, maxValue));
|
rewriter.getFloatAttr(resultDtype, maxValue));
|
||||||
max = rewriter.create<Torch::ValueTensorLiteralOp>(
|
max = rewriter.create<Torch::ValueTensorLiteralOp>(
|
||||||
binder.getLoc(), resultType, maxSplatAttr);
|
binder.getLoc(), resultType, maxSplatAttr);
|
||||||
|
@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
if (binder.op->hasAttr("torch.onnx.value_float") &&
|
if (binder.op->hasAttr("torch.onnx.value_float") &&
|
||||||
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
|
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
|
||||||
auto splatAttr =
|
auto splatAttr =
|
||||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
SplatElementsAttr::get(resultType.toBuiltinTensor(),
|
||||||
rewriter.getFloatAttr(dtype, floatValue));
|
rewriter.getFloatAttr(dtype, floatValue));
|
||||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||||
binder.op, resultType, splatAttr);
|
binder.op, resultType, splatAttr);
|
||||||
|
@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
if (binder.op->hasAttr("torch.onnx.value_int") &&
|
if (binder.op->hasAttr("torch.onnx.value_int") &&
|
||||||
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
|
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
|
||||||
auto splatAttr =
|
auto splatAttr =
|
||||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
SplatElementsAttr::get(resultType.toBuiltinTensor(),
|
||||||
rewriter.getIntegerAttr(dtype, intValue));
|
rewriter.getIntegerAttr(dtype, intValue));
|
||||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||||
binder.op, resultType, splatAttr);
|
binder.op, resultType, splatAttr);
|
||||||
|
@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
for (auto intVal : intValues) {
|
for (auto intVal : intValues) {
|
||||||
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
|
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
|
||||||
}
|
}
|
||||||
auto attr = DenseElementsAttr::get(
|
auto attr =
|
||||||
resultType.toBuiltinTensor().clone(dtype), apValues);
|
DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
|
||||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||||
binder.op, resultType, attr);
|
binder.op, resultType, attr);
|
||||||
return success();
|
return success();
|
||||||
|
@ -2272,8 +2272,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
// Extract the fill value and dtype
|
// Extract the fill value and dtype
|
||||||
// ONNX requires value attr to be a tensor
|
// ONNX requires value attr to be a tensor
|
||||||
if (!attr) {
|
if (!attr) {
|
||||||
attr = DenseElementsAttr::get(
|
attr =
|
||||||
resultType.toBuiltinTensor().clone(resultDType),
|
DenseElementsAttr::get(resultType.toBuiltinTensor(),
|
||||||
rewriter.getFloatAttr(resultDType, 0.0));
|
rewriter.getFloatAttr(resultDType, 0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -146,12 +146,11 @@ public:
|
||||||
"mismatching contracting dimension for torch.aten.mm"));
|
"mismatching contracting dimension for torch.aten.mm"));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
TensorType resultType =
|
||||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type elementType = resultType.getElementType();
|
||||||
Type elementType = cast<TensorType>(newResultType).getElementType();
|
auto accumulatorDType = getDefaultAccType(rewriter, elementType);
|
||||||
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
|
if (accumulatorDType != resultType.getElementType()) {
|
||||||
if (accumulatorDType != resultDTy) {
|
|
||||||
elementType = accumulatorDType;
|
elementType = accumulatorDType;
|
||||||
}
|
}
|
||||||
Value zeroFill = createZeroInitTensor(
|
Value zeroFill = createZeroInitTensor(
|
||||||
|
@ -197,18 +196,16 @@ public:
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultType.getElementType()) {
|
||||||
Type resultElementType =
|
|
||||||
cast<RankedTensorType>(newResultType).getElementType();
|
|
||||||
matmul = torch_to_linalg::convertTensorToElementType(
|
matmul = torch_to_linalg::convertTensorToElementType(
|
||||||
rewriter, loc, matmul, resultElementType);
|
rewriter, loc, matmul, resultType.getElementType());
|
||||||
}
|
}
|
||||||
// When constructed with just dynamic sizes, EmptyOp will have a result
|
// When constructed with just dynamic sizes, EmptyOp will have a result
|
||||||
// type which has all `?`'s for dimensions, which might not be the result
|
// type which has all `?`'s for dimensions, which might not be the result
|
||||||
// type of `op`. The constraints on later linalg ops means that the result
|
// type of `op`. The constraints on later linalg ops means that the result
|
||||||
// of the MatmulOp will have this type too. So cast it to the desired type
|
// of the MatmulOp will have this type too. So cast it to the desired type
|
||||||
// so that in the end we have the original result type.
|
// so that in the end we have the original result type.
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmul);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto dty = resultTy.getDtype();
|
auto dty = resultTy.getDtype();
|
||||||
auto resultBTy = resultTy.toBuiltinTensor().clone(dty);
|
auto resultBTy = resultTy.toBuiltinTensor();
|
||||||
|
|
||||||
auto fpTy = dyn_cast<mlir::FloatType>(dty);
|
auto fpTy = dyn_cast<mlir::FloatType>(dty);
|
||||||
auto intTy = dyn_cast<mlir::IntegerType>(dty);
|
auto intTy = dyn_cast<mlir::IntegerType>(dty);
|
||||||
|
@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!ty || !ty.hasDtype() || !ty.hasSizes())
|
if (!ty || !ty.hasDtype() || !ty.hasSizes())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
|
auto bty = ty.toBuiltinTensor();
|
||||||
if (!bty.hasStaticShape())
|
if (!bty.hasStaticShape())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto ctx = lhs.getContext();
|
auto ctx = lhs.getContext();
|
||||||
auto resultETy = resultTy.getDtype();
|
|
||||||
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
|
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
|
||||||
if (lhs.isSplat()) {
|
if (lhs.isSplat()) {
|
||||||
if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
|
if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
|
||||||
|
@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
||||||
unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
|
unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
|
||||||
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
||||||
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
||||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
|
||||||
resultAP);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
||||||
|
@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
||||||
auto resultBool =
|
auto resultBool =
|
||||||
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
||||||
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
||||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
|
||||||
resultAP);
|
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
||||||
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
||||||
values.push_back(resultBool);
|
values.push_back(resultBool);
|
||||||
}
|
}
|
||||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
|
||||||
values);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
||||||
|
@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
||||||
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
||||||
values.push_back(resultBool);
|
values.push_back(resultBool);
|
||||||
}
|
}
|
||||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
|
||||||
values);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
|
||||||
if (!fpTy && !intTy)
|
if (!fpTy && !intTy)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
|
auto resultBTy = resultTy.toBuiltinTensor();
|
||||||
bool splat = operand.isSplat();
|
bool splat = operand.isSplat();
|
||||||
bool withinMaxFold =
|
bool withinMaxFold =
|
||||||
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
|
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
|
||||||
|
@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto selfTy = cast<ShapedType>(self.getType());
|
auto selfTy = cast<ShapedType>(self.getType());
|
||||||
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
|
auto bty = ty.toBuiltinTensor();
|
||||||
if (!bty.hasStaticShape())
|
if (!bty.hasStaticShape())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
|
||||||
|
|
||||||
if (!indicesTensorType.hasDtype())
|
if (!indicesTensorType.hasDtype())
|
||||||
return failure();
|
return failure();
|
||||||
auto indicesType =
|
auto indicesType = indicesTensorType.toBuiltinTensor();
|
||||||
indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype());
|
|
||||||
if (!indicesType || !indicesType.hasStaticShape())
|
if (!indicesType || !indicesType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -3612,8 +3606,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (input && input.isSplat())
|
if (input && input.isSplat())
|
||||||
return DenseElementsAttr::get(
|
return DenseElementsAttr::get(outType.toBuiltinTensor(),
|
||||||
outType.toBuiltinTensor().clone(inType.getDtype()),
|
|
||||||
input.getSplatValue<Attribute>());
|
input.getSplatValue<Attribute>());
|
||||||
|
|
||||||
int count = 1;
|
int count = 1;
|
||||||
|
@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
for (int i = begin; i < limit; i += stride)
|
for (int i = begin; i < limit; i += stride)
|
||||||
values.push_back(input.getValues<Attribute>()[i]);
|
values.push_back(input.getValues<Attribute>()[i]);
|
||||||
|
|
||||||
return DenseElementsAttr::get(
|
return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
|
||||||
outType.toBuiltinTensor().clone(inType.getDtype()), values);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the input and output shapes are the same we can just fold:
|
// If the input and output shapes are the same we can just fold:
|
||||||
|
@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
Type eTy = resultTy.getDtype();
|
Type eTy = resultTy.getDtype();
|
||||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||||
|
|
||||||
SmallVector<int64_t> data;
|
SmallVector<int64_t> data;
|
||||||
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
|
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
|
||||||
|
@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
Type eTy = resultTy.getDtype();
|
Type eTy = resultTy.getDtype();
|
||||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||||
|
|
||||||
int64_t data;
|
int64_t data;
|
||||||
if (matchPattern(getT(), m_TorchConstantInt(&data))) {
|
if (matchPattern(getT(), m_TorchConstantInt(&data))) {
|
||||||
|
@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
Type eTy = resultTy.getDtype();
|
Type eTy = resultTy.getDtype();
|
||||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||||
|
|
||||||
double data;
|
double data;
|
||||||
if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
|
if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
|
||||||
|
@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
|
||||||
: selfAttr.getValues<Attribute>()[indexInt];
|
: selfAttr.getValues<Attribute>()[indexInt];
|
||||||
|
|
||||||
auto dty = resultTy.getDtype();
|
auto dty = resultTy.getDtype();
|
||||||
auto attrTy = resultTy.toBuiltinTensor().clone(dty);
|
auto attrTy = resultTy.toBuiltinTensor();
|
||||||
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
|
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
|
||||||
return DenseElementsAttr::get(
|
return DenseElementsAttr::get(
|
||||||
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
|
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
|
||||||
|
@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
||||||
if (!valueDense.isSplat())
|
if (!valueDense.isSplat())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto splattr = valueDense.getSplatValue<Attribute>();
|
auto splattr = valueDense.getSplatValue<Attribute>();
|
||||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
auto attrty = ty.toBuiltinTensor();
|
||||||
return DenseElementsAttr::get(attrty, splattr);
|
return DenseElementsAttr::get(attrty, splattr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
||||||
if (!isa<mlir::IntegerType>(dty))
|
if (!isa<mlir::IntegerType>(dty))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
int64_t intval = intAttr.getInt();
|
int64_t intval = intAttr.getInt();
|
||||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
auto attrty = ty.toBuiltinTensor();
|
||||||
return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
|
return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
||||||
if (!isa<mlir::FloatType>(dty))
|
if (!isa<mlir::FloatType>(dty))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
double dblval = fpAttr.getValueAsDouble();
|
double dblval = fpAttr.getValueAsDouble();
|
||||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
auto attrty = ty.toBuiltinTensor();
|
||||||
return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
|
return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
||||||
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
|
if (isa<mlir::FloatType, IntegerType, mlir::ComplexType>(dtype)) {
|
||||||
return dtype;
|
|
||||||
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
|
|
||||||
return IntegerType::get(context, integerType.getWidth(),
|
|
||||||
IntegerType::Signless);
|
|
||||||
} else if (isa<mlir::ComplexType>(dtype)) {
|
|
||||||
return dtype;
|
return dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
||||||
TensorType ValueTensorType::toBuiltinTensor() const {
|
TensorType ValueTensorType::toBuiltinTensor() const {
|
||||||
if (!hasDtype())
|
if (!hasDtype())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (!hasSizes())
|
|
||||||
return UnrankedTensorType::get(getDtype());
|
|
||||||
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
|
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
|
||||||
if (!elementType)
|
if (!elementType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
if (!hasSizes())
|
||||||
|
return UnrankedTensorType::get(elementType);
|
||||||
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
|
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
|
||||||
getOptionalSparsity());
|
getOptionalSparsity());
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
|
||||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||||
auto valueTensorTypeConversion =
|
auto valueTensorTypeConversion =
|
||||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||||
return type.toBuiltinTensor();
|
auto builtinType = type.toBuiltinTensor();
|
||||||
|
if (!builtinType)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
// convert any integer type to signless
|
||||||
|
if (type.getDtype().isInteger()) {
|
||||||
|
return builtinType.clone(IntegerType::get(
|
||||||
|
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
|
||||||
|
IntegerType::Signless));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builtinType;
|
||||||
};
|
};
|
||||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||||
valueTensorTypeConversion);
|
valueTensorTypeConversion);
|
||||||
|
@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
|
||||||
auto valueTensorTypeConversion =
|
auto valueTensorTypeConversion =
|
||||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||||
auto builtinType = type.toBuiltinTensor();
|
auto builtinType = type.toBuiltinTensor();
|
||||||
|
if (!builtinType)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
// convert signed integer type to signless, keep unsigned as unsigned
|
||||||
if (type.getDtype().isUnsignedInteger()) {
|
if (type.getDtype().isUnsignedInteger()) {
|
||||||
return builtinType.clone(type.getDtype());
|
return builtinType.clone(type.getDtype());
|
||||||
|
} else if (type.getDtype().isSignedInteger()) {
|
||||||
|
return builtinType.clone(IntegerType::get(
|
||||||
|
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
|
||||||
|
IntegerType::Signless));
|
||||||
}
|
}
|
||||||
|
|
||||||
return builtinType;
|
return builtinType;
|
||||||
};
|
};
|
||||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||||
|
|
Loading…
Reference in New Issue