Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)

We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
pull/3141/head
penguin_wwy 2024-04-11 21:47:35 +08:00 committed by GitHub
parent 308c45e61a
commit d4a30b7e67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 405 additions and 409 deletions

View File

@ -178,7 +178,7 @@ struct OpBinder {
} }
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) { for (auto element : arrayAttr) {
auto integerAttr = element.dyn_cast<IntegerAttr>(); auto integerAttr = dyn_cast<IntegerAttr>(element);
if (!integerAttr) if (!integerAttr)
return failure(); return failure();
IntegerType t = cast<IntegerType>(integerAttr.getType()); IntegerType t = cast<IntegerType>(integerAttr.getType());
@ -200,7 +200,7 @@ struct OpBinder {
return success(); return success();
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) { for (auto element : arrayAttr) {
StringAttr stringAttr = element.dyn_cast<StringAttr>(); StringAttr stringAttr = dyn_cast<StringAttr>(element);
if (!stringAttr) if (!stringAttr)
return failure(); return failure();
values.push_back(stringAttr.getValue().str()); values.push_back(stringAttr.getValue().str());

View File

@ -94,7 +94,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
// Compute the knowledge based on the inferred type. // Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType(); inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank(); inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) { if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) { for (auto dim : predictedShape.getDims()) {

View File

@ -1287,7 +1287,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), axisScalar, finalOffset); binder.getLoc(), axisScalar, finalOffset);
Torch::BaseTensorType resultTensorType = Torch::BaseTensorType resultTensorType =
resultType.cast<Torch::BaseTensorType>(); cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) { if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype"); binder.op, "expected result type to have a dtype");
@ -1899,7 +1899,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// If its a dense resource attr we need to convert to a dense type: // If its a dense resource attr we need to convert to a dense type:
if (DenseResourceElementsAttr rattr = if (DenseResourceElementsAttr rattr =
attr.dyn_cast_or_null<DenseResourceElementsAttr>()) { dyn_cast_or_null<DenseResourceElementsAttr>(attr)) {
// Bytes are stored in little endian order. Big endian support will // Bytes are stored in little endian order. Big endian support will
// require swizzling. // require swizzling.
if (!Endian::little) { if (!Endian::little) {
@ -1916,7 +1916,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Attribute splattr; Attribute splattr;
if (isa<SplatElementsAttr>(attr)) { if (isa<SplatElementsAttr>(attr)) {
auto denseAttr = attr.cast<DenseElementsAttr>(); auto denseAttr = cast<DenseElementsAttr>(attr);
splattr = denseAttr.getSplatValue<Attribute>(); splattr = denseAttr.getSplatValue<Attribute>();
} }

View File

@ -1366,7 +1366,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// set the splitted axis to variable shape // set the splitted axis to variable shape
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes()); llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) { for (auto result : binder.op->getResultTypes()) {
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim]; int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
} }
@ -1437,7 +1437,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes()); llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) { for (auto result : binder.op->getResultTypes()) {
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim]; int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
} }

View File

@ -272,9 +272,9 @@ public:
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType); convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
Value operandB = Value operandB =
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType); convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
if (resultType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB); rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
} else if (resultType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB); rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
} else { } else {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(

View File

@ -1881,7 +1881,7 @@ public:
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = getElementTypeOrSelf(input.getType()); auto inputElementType = getElementTypeOrSelf(input.getType());
if (!inputElementType.isa<ComplexType>()) { if (!isa<ComplexType>(inputElementType)) {
return op.emitError("only ComplexType is allowed as input type"); return op.emitError("only ComplexType is allowed as input type");
} }
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();

View File

@ -131,7 +131,7 @@ public:
auto resultTy = op.getType().cast<ValueTensorType>(); auto resultTy = op.getType().cast<ValueTensorType>();
auto resultDTy = resultTy.toBuiltinTensor().getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType(); Type elementType = cast<TensorType>(newResultType).getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
elementType = accumulatorDType; elementType = accumulatorDType;
@ -201,7 +201,7 @@ public:
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
matmul = torch_to_linalg::convertTensorToElementType( matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType); rewriter, loc, matmul, resultElementType);
} }
@ -307,7 +307,7 @@ public:
unsigned rhsRank = rhsType.getRank(); unsigned rhsRank = rhsType.getRank();
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = newResultType.cast<RankedTensorType>(); auto resultType = cast<RankedTensorType>(newResultType);
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
// The different cases of torch_matmul op is mentioned here: // The different cases of torch_matmul op is mentioned here:
@ -600,9 +600,9 @@ public:
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>(); RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType(); Type lhsElementType = cast<RankedTensorType>(lhsType).getElementType();
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType(); Type rhsElementType = cast<RankedTensorType>(rhsType).getElementType();
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) { if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -712,9 +712,9 @@ public:
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType(); auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
auto resultDTy = resultTy.toBuiltinTensor().getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType();
if (!inputDTy.isa<mlir::FloatType, mlir::IntegerType>() || if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
!weightDTy.isa<mlir::FloatType, mlir::IntegerType>() || !isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
!resultDTy.isa<mlir::FloatType, mlir::IntegerType>()) !isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
return op.emitError("unimplemented: non-fp not-int type"); return op.emitError("unimplemented: non-fp not-int type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank(); size_t inRank = input.getType().cast<RankedTensorType>().getRank();
size_t numSpatialDims = inRank - 2; size_t numSpatialDims = inRank - 2;
@ -790,9 +790,8 @@ public:
SmallVector<Value> outDims{inBatch, weightBatch}; SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput; Value paddedInput;
if (transposed) { if (transposed) {
if (!inputDTy.isa<mlir::FloatType>() || if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
!weightDTy.isa<mlir::FloatType>() || !isa<mlir::FloatType>(resultDTy))
!resultDTy.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet"); op, "transpose does not support non-fp type yet");
@ -927,10 +926,10 @@ public:
accumulatorDType); accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) { if (bias.getType().isa<Torch::NoneType>()) {
Value c0; Value c0;
if (accumulatorDType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>( c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(accumulatorDType, 0.0)); loc, FloatAttr::get(accumulatorDType, 0.0));
} else if (accumulatorDType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>( c0 = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(accumulatorDType, 0)); loc, IntegerAttr::get(accumulatorDType, 0));
} }
@ -1021,7 +1020,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType); resultElementType);
} }
@ -1081,7 +1080,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType); resultElementType);
} }
@ -1125,7 +1124,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType); resultElementType);
} }
@ -1203,7 +1202,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultDTy) {
Type resultElementType = Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType); resultElementType);
} }

View File

@ -154,7 +154,7 @@ static LogicalResult createPoolingOp(
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) { SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc(); Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput) if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type"); return op->emitError("unimplemented: non-floating point type");
Value initValue = Value initValue =
@ -217,7 +217,7 @@ private:
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
Value initValue = Value initValue =
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr); rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
@ -335,7 +335,7 @@ public:
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getInf( APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>( if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
@ -416,7 +416,7 @@ public:
// `maxpool2d` contains the result of maxpool2d operation over the input. // `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr( auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
Value maxPool2d, paddedInput; Value maxPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape; SmallVector<Value, 4> outTensorShape;
@ -555,7 +555,7 @@ public:
self.getType().cast<RankedTensorType>().getElementType(); self.getType().cast<RankedTensorType>().getElementType();
Type resultType = typeConverter->convertType(op.getType()); Type resultType = typeConverter->convertType(op.getType());
Type resultElementType = Type resultElementType =
resultType.cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(resultType).getElementType();
bool ceilMode; bool ceilMode;
SmallVector<Value, Dim> kernelSizeIntValues; SmallVector<Value, Dim> kernelSizeIntValues;
@ -615,9 +615,9 @@ public:
/*iteratorTypes=*/iteratorTypesAvg, /*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value avg; Value avg;
if (resultElementType.isa<mlir::IntegerType>()) if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor); avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
else if (resultElementType.isa<mlir::FloatType>()) else if (isa<mlir::FloatType>(resultElementType))
avg = b.create<arith::DivFOp>(loc, args[0], divisor); avg = b.create<arith::DivFOp>(loc, args[0], divisor);
b.create<linalg::YieldOp>(loc, avg); b.create<linalg::YieldOp>(loc, avg);
}) })
@ -707,7 +707,7 @@ public:
Type auxTensorElementType = auxTensorType.getElementType(); Type auxTensorElementType = auxTensorType.getElementType();
auto smallestFPValueAttr = rewriter.getFloatAttr( auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType, buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
smallestFPValueAttr); smallestFPValueAttr);

View File

@ -130,7 +130,7 @@ public:
RankedTensorType resultType = self.getType().cast<RankedTensorType>(); RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType(); Type elemTy = resultType.getElementType();
if (!elemTy.isa<mlir::FloatType>()) if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type"); return rewriter.notifyMatchFailure(op, "This op only support float type");
if (!generator.getType().isa<Torch::NoneType>()) if (!generator.getType().isa<Torch::NoneType>())

View File

@ -70,7 +70,7 @@ public:
input.getType().template cast<RankedTensorType>(); input.getType().template cast<RankedTensorType>();
Type idxElementType = Type idxElementType =
getElementTypeOrSelf(typec->convertType(idxResultType)); getElementTypeOrSelf(typec->convertType(idxResultType));
if (!idxElementType.isa<IntegerType>()) if (!isa<IntegerType>(idxElementType))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires integer-like result type"); op, opName + " to linalg.* requires integer-like result type");
@ -89,8 +89,8 @@ public:
Type inElementType = inputType.getElementType(); Type inElementType = inputType.getElementType();
bool isUnsigned = false; bool isUnsigned = false;
if (!inElementType.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inElementType)) {
if (inElementType.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(inElementType)) {
auto integerTy = op.getSelf() auto integerTy = op.getSelf()
.getType() .getType()
.template cast<BaseTensorType>() .template cast<BaseTensorType>()
@ -121,22 +121,21 @@ public:
loc, getAsOpFoldResult(resultShape), inElementType); loc, getAsOpFoldResult(resultShape), inElementType);
Value fillValue; Value fillValue;
if (inElementType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(inElementType)) {
fillValue = rewriter.create<arith::ConstantOp>( fillValue = rewriter.create<arith::ConstantOp>(
loc, loc, rewriter.getFloatAttr(
rewriter.getFloatAttr(
inElementType, inElementType,
APFloat::getInf( APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(inElementType).getFloatSemantics(),
/*Negative=*/isMax))); /*Negative=*/isMax)));
} else if (!isUnsigned) { } else if (!isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth(); auto width = cast<mlir::IntegerType>(inElementType).getWidth();
auto init = isMax ? APSInt::getSignedMinValue(width) auto init = isMax ? APSInt::getSignedMinValue(width)
: APSInt::getSignedMaxValue(width); : APSInt::getSignedMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>( fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init)); loc, rewriter.getIntegerAttr(inElementType, init));
} else if (isUnsigned) { } else if (isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth(); auto width = cast<mlir::IntegerType>(inElementType).getWidth();
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width); auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>( fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init)); loc, rewriter.getIntegerAttr(inElementType, init));
@ -180,7 +179,7 @@ public:
rewriter.create<linalg::IndexOp>(loc, dim)); rewriter.create<linalg::IndexOp>(loc, dim));
Value resultVal, predicate; Value resultVal, predicate;
if (inElementType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(inElementType)) {
arith::CmpFPredicate predType; arith::CmpFPredicate predType;
if (isMax) { if (isMax) {
predType = arith::CmpFPredicate::OGT; predType = arith::CmpFPredicate::OGT;
@ -300,21 +299,21 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType)); return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
if (isa<AtenProdDimIntOp>(op)) { if (isa<AtenProdDimIntOp>(op)) {
if (elementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0)); return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0));
else if (elementType.isa<mlir::IntegerType>()) else if (isa<mlir::IntegerType>(elementType))
return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1)); return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1));
} }
if (isa<AtenMaxOp>(op)) { if (isa<AtenMaxOp>(op)) {
if (elementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>( return b.create<arith::ConstantOp>(
loc, b.getFloatAttr( loc, b.getFloatAttr(
elementType, elementType,
APFloat::getInf( APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true))); /*Negative=*/true)));
else if (elementType.isa<mlir::IntegerType>() && else if (isa<mlir::IntegerType>(elementType) &&
elementType.getIntOrFloatBitWidth() != 8) elementType.getIntOrFloatBitWidth() != 8)
return b.create<arith::ConstantOp>( return b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(elementType, loc, b.getIntegerAttr(elementType,
@ -323,14 +322,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
} }
if (isa<AtenMinOp>(op)) { if (isa<AtenMinOp>(op)) {
if (elementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>( return b.create<arith::ConstantOp>(
loc, b.getFloatAttr( loc, b.getFloatAttr(
elementType, elementType,
APFloat::getInf( APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/false))); /*Negative=*/false)));
else if (elementType.isa<mlir::IntegerType>() && else if (isa<mlir::IntegerType>(elementType) &&
elementType.getIntOrFloatBitWidth() != 8) elementType.getIntOrFloatBitWidth() != 8)
return b.create<arith::ConstantOp>( return b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(elementType, loc, b.getIntegerAttr(elementType,
@ -359,25 +358,25 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self = Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1]; Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::AddFOp>(loc, self, result); return b.create<arith::AddFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) else if (isa<mlir::IntegerType>(resultElementType))
return b.create<arith::AddIOp>(loc, self, result); return b.create<arith::AddIOp>(loc, self, result);
} else if (isa<AtenProdDimIntOp>(op)) { } else if (isa<AtenProdDimIntOp>(op)) {
Value self = Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1]; Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MulFOp>(loc, self, result); return b.create<arith::MulFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) else if (isa<mlir::IntegerType>(resultElementType))
return b.create<arith::MulIOp>(loc, self, result); return b.create<arith::MulIOp>(loc, self, result);
} else if (auto max = dyn_cast<AtenMaxOp>(op)) { } else if (auto max = dyn_cast<AtenMaxOp>(op)) {
Value self = Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1]; Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MaximumFOp>(loc, self, result); return b.create<arith::MaximumFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) { else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = max.getSelf() IntegerType intType = max.getSelf()
.getType() .getType()
.cast<BaseTensorType>() .cast<BaseTensorType>()
@ -392,9 +391,9 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self = Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1]; Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MinimumFOp>(loc, self, result); return b.create<arith::MinimumFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) { else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = min.getSelf() IntegerType intType = min.getSelf()
.getType() .getType()
.cast<BaseTensorType>() .cast<BaseTensorType>()
@ -626,10 +625,10 @@ private:
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) || if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) && isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>()) !isa<mlir::FloatType>(elemType))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops"); op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() && if (isa<AtenAllDimOp>(op) && isa<mlir::IntegerType>(elemType) &&
elemType.getIntOrFloatBitWidth() == 8) elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported"); return rewriter.notifyMatchFailure(op, "uint8 is not supported");

View File

@ -100,7 +100,7 @@ public:
} }
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(newResultType).getElementType();
Value castedValue = Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
@ -553,7 +553,7 @@ public:
// The size of the result is calculated as follows: // The size of the result is calculated as follows:
// ceil((end - start)/step) // ceil((end - start)/step)
Value resultShape; Value resultShape;
if (dtype.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(dtype)) {
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start); Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step); resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
} else { } else {
@ -585,7 +585,7 @@ public:
index = castIndexToInt64(b, loc, index); index = castIndexToInt64(b, loc, index);
index = convertScalarToDtype(b, loc, index, dtype); index = convertScalarToDtype(b, loc, index, dtype);
Value mulOut, result; Value mulOut, result;
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
mulOut = b.create<arith::MulFOp>(loc, step, index); mulOut = b.create<arith::MulFOp>(loc, step, index);
result = b.create<arith::AddFOp>(loc, start, mulOut); result = b.create<arith::AddFOp>(loc, start, mulOut);
} else { } else {

View File

@ -35,16 +35,16 @@ using namespace mlir::torch::Torch;
template <typename elementType> static bool hasElementType(Value tensor) { template <typename elementType> static bool hasElementType(Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>(); auto tensorType = tensor.getType().cast<RankedTensorType>();
Type tensorElementType = tensorType.getElementType(); Type tensorElementType = tensorType.getElementType();
return tensorElementType.isa<elementType>(); return isa<elementType>(tensorElementType);
} }
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred, template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
arith::CmpIPredicate ispred> arith::CmpIPredicate ispred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
Value lhs, Value rhs) { Value lhs, Value rhs) {
if (type.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(type))
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs); return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(type)) {
if (intType.isUnsigned()) if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs); return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
if (intType.isSigned()) if (intType.isSigned())
@ -319,7 +319,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseAndScalar.getType()) Type dtype = converter->convertType(bitwiseAndScalar.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::IntegerType>()) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseAndScalar.emitError( bitwiseAndScalar.emitError(
"bitwise_and.Scalar does not support non-integer input dtype."); "bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr; return nullptr;
@ -371,7 +371,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::IntegerType>()) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseRightShiftTensor.emitError( bitwiseRightShiftTensor.emitError(
"Bitwise_Right_Shift op does not support non-integer input dtype."); "Bitwise_Right_Shift op does not support non-integer input dtype.");
return nullptr; return nullptr;
@ -385,7 +385,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::IntegerType>()) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseLeftShiftTensor.emitError( bitwiseLeftShiftTensor.emitError(
"Bitwise_Left_Shift op does not support non-integer input dtype."); "Bitwise_Left_Shift op does not support non-integer input dtype.");
return nullptr; return nullptr;
@ -623,7 +623,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType); /*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha); Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled); return b.create<arith::AddFOp>(loc, lhs, scaled);
} else { } else {
@ -647,7 +647,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType, /*dstOriginalDtype=*/resultElementType,
/*originalScalar=*/sub.getAlpha()); /*originalScalar=*/sub.getAlpha());
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha); Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled); return b.create<arith::SubFOp>(loc, lhs, scaled);
} else { } else {
@ -664,10 +664,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype( Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype); /*dstOriginalDtype=*/dtype);
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha); Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::SubFOp>(loc, self, mult); return b.create<arith::SubFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha); Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::SubIOp>(loc, self, mult); return b.create<arith::SubIOp>(loc, self, mult);
} }
@ -690,10 +690,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype, Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType); /*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha); Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::AddFOp>(loc, self, mult); return b.create<arith::AddFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha); Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::AddIOp>(loc, self, mult); return b.create<arith::AddIOp>(loc, self, mult);
} }
@ -708,9 +708,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
return b.create<arith::MulFOp>(loc, lhs, rhs); return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if (dtype.isa<mlir::ComplexType>()) { } else if (isa<mlir::ComplexType>(dtype)) {
return b.create<complex::MulOp>(loc, lhs, rhs); return b.create<complex::MulOp>(loc, lhs, rhs);
} else { } else {
return b.create<arith::MulIOp>(loc, lhs, rhs); return b.create<arith::MulIOp>(loc, lhs, rhs);
@ -720,7 +720,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(atan2.getType()) Type dtype = converter->convertType(atan2.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
atan2.emitError("Atan2 requires floating point result type"); atan2.emitError("Atan2 requires floating point result type");
return nullptr; return nullptr;
} }
@ -759,9 +759,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::DivFOp>(loc, lhs, rhs); return b.create<arith::DivFOp>(loc, lhs, rhs);
else if (dtype.isa<mlir::IntegerType>()) { else if (isa<mlir::IntegerType>(dtype)) {
if (dtype.isUnsignedInteger()) if (dtype.isUnsignedInteger())
return b.create<arith::DivUIOp>(loc, lhs, rhs); return b.create<arith::DivUIOp>(loc, lhs, rhs);
return b.create<arith::DivSIOp>(loc, lhs, rhs); return b.create<arith::DivSIOp>(loc, lhs, rhs);
@ -777,7 +777,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value div; Value div;
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
div = b.create<arith::DivFOp>(loc, lhs, rhs); div = b.create<arith::DivFOp>(loc, lhs, rhs);
else { else {
if (dtype.isUnsignedInteger()) if (dtype.isUnsignedInteger())
@ -798,7 +798,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (roundingMode == "trunc") { if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent // "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division. // to C-style integer division.
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value ceil = b.create<math::CeilOp>(loc, div); Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div); Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype)); Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
@ -811,7 +811,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (roundingMode == "floor") { if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to // "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator) // floor division in Python (the // operator)
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, div); return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) { else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type(); Type defaultIntToFloatType = b.getF64Type();
@ -831,7 +831,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) { if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype(); Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -857,7 +857,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(pow.getType()) Type dtype = converter->convertType(pow.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -870,7 +870,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(imag.getType()) Type dtype = converter->convertType(imag.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
imag.emitError("unimplemented: non-floating point dtype"); imag.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -882,7 +882,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(real.getType()) Type dtype = converter->convertType(real.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
real.emitError("unimplemented: non-floating point dtype"); real.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -898,10 +898,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float. // TODO: Promote tensor args from integer to float.
gtScalar.emitError( gtScalar.emitError(
@ -928,10 +928,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float. // TODO: Promote tensor args from integer to float.
geScalar.emitError( geScalar.emitError(
@ -955,7 +955,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
eqScalar.emitError( eqScalar.emitError(
@ -971,7 +971,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
neScalar.emitError( neScalar.emitError(
@ -989,10 +989,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored. // a lot of code that can be refactored.
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
ltScalar.emitError( ltScalar.emitError(
@ -1017,10 +1017,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
// that can be refactored. // that can be refactored.
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
leScalar.emitError( leScalar.emitError(
@ -1096,14 +1096,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(clamp.getType()) Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) { if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
clamp.emitError("unimplement type for clamp"); clamp.emitError("unimplement type for clamp");
return nullptr; return nullptr;
} }
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype(); Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype); bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) { if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
isUnsigned = intTy.isUnsigned(); isUnsigned = intTy.isUnsigned();
} }
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
@ -1112,11 +1112,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*dstOriginalDtype=*/dstOriginalDtype); /*dstOriginalDtype=*/dstOriginalDtype);
Value pred; Value pred;
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
auto cmp = auto cmp =
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp); pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
auto cmp = auto cmp =
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
if (getMax) if (getMax)
@ -1151,10 +1151,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
isMinNone = false; isMinNone = false;
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value pred; Value pred;
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result, pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result,
minPromoted); minPromoted);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result, pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result,
minPromoted); minPromoted);
} else { } else {
@ -1169,10 +1169,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
max = isMinNone ? payloadArgs[1] : payloadArgs[2]; max = isMinNone ? payloadArgs[1] : payloadArgs[2];
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
Value pred; Value pred;
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result, pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result,
maxPromoted); maxPromoted);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result, pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result,
maxPromoted); maxPromoted);
} else { } else {
@ -1194,10 +1194,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype( Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype); /*dstOriginalDtype=*/dtype);
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, self, alpha); Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult); return b.create<arith::SubFOp>(loc, other, mult);
} else if (dtype.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, self, alpha); Value mult = b.create<arith::MulIOp>(loc, self, alpha);
return b.create<arith::SubIOp>(loc, other, mult); return b.create<arith::SubIOp>(loc, other, mult);
} }
@ -1211,9 +1211,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
return b.create<arith::MulFOp>(loc, lhs, rhs); return b.create<arith::MulFOp>(loc, lhs, rhs);
if (dtype.isa<mlir::IntegerType>()) if (isa<mlir::IntegerType>(dtype))
return b.create<arith::MulIOp>(loc, lhs, rhs); return b.create<arith::MulIOp>(loc, lhs, rhs);
mulScalar.emitError("unimplemented: Only integer/float dtype supported"); mulScalar.emitError("unimplemented: Only integer/float dtype supported");
return nullptr; return nullptr;
@ -1246,7 +1246,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(divScalar.getType()) Type dtype = converter->convertType(divScalar.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(dtype)) {
divScalar.emitError("unimplemented: non-floating point dtype"); divScalar.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -1263,9 +1263,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
Value result; Value result;
if (newResultType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other); result = b.create<arith::RemFOp>(loc, self, other);
} else if (newResultType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other); result = b.create<arith::RemSIOp>(loc, self, other);
} else { } else {
remScalar.emitError( remScalar.emitError(
@ -1283,9 +1283,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result; Value result;
if (newResultType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other); result = b.create<arith::RemFOp>(loc, self, other);
} else if (newResultType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other); result = b.create<arith::RemSIOp>(loc, self, other);
} else { } else {
remTensor.emitError( remTensor.emitError(
@ -1303,12 +1303,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result; Value result;
if (newResultType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(newResultType)) {
Value n = b.create<arith::DivFOp>(loc, self, other); Value n = b.create<arith::DivFOp>(loc, self, other);
n = b.create<math::TruncOp>(loc, n); n = b.create<math::TruncOp>(loc, n);
Value n_y = b.create<arith::MulFOp>(loc, n, other); Value n_y = b.create<arith::MulFOp>(loc, n, other);
result = b.create<arith::SubFOp>(loc, self, n_y); result = b.create<arith::SubFOp>(loc, self, n_y);
} else if (newResultType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(newResultType)) {
Value n = b.create<arith::DivSIOp>(loc, self, other); Value n = b.create<arith::DivSIOp>(loc, self, other);
Value n_y = b.create<arith::MulIOp>(loc, n, other); Value n_y = b.create<arith::MulIOp>(loc, n, other);
result = b.create<arith::SubIOp>(loc, self, n_y); result = b.create<arith::SubIOp>(loc, self, n_y);
@ -1349,7 +1349,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
Value predicate; Value predicate;
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self, predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold); threshold);
else else
@ -1372,7 +1372,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype)); Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value predicate; Value predicate;
if (dtype.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self, predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold); threshold);
else else
@ -1426,7 +1426,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type elementType = converter->convertType(bitwiseNot.getType()) Type elementType = converter->convertType(bitwiseNot.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (elementType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementType)) {
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
return nullptr; return nullptr;
} }
@ -2253,7 +2253,7 @@ public:
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = inputType.getElementType(); auto inputElementType = inputType.getElementType();
if (!inputElementType.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inputElementType)) {
op.emitError("Logit does not support non-floating point type"); op.emitError("Logit does not support non-floating point type");
return failure(); return failure();
} }

View File

@ -554,7 +554,7 @@ FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
} }
Type type = *maybeType; Type type = *maybeType;
// The linalg-on-tensors backend currently expects integers to be signless. // The linalg-on-tensors backend currently expects integers to be signless.
if (auto intType = type.dyn_cast<IntegerType>()) { if (auto intType = dyn_cast<IntegerType>(type)) {
type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless); type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
} }
return type; return type;

View File

@ -140,11 +140,11 @@ public:
// If the target type is non-torch type, then use TypeConverter to convert // If the target type is non-torch type, then use TypeConverter to convert
// the type of the source. // the type of the source.
if (targetType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(targetType)) {
targetType = Torch::FloatType::get(op->getContext()); targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion( torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to}); rewriter, scfWhileOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(targetType)) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth(); unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1) if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext()); targetType = Torch::BoolType::get(op->getContext());
@ -179,7 +179,7 @@ public:
// If the argument is a torch tensor, directly add it in the list of // If the argument is a torch tensor, directly add it in the list of
// iter args. // iter args.
if (torchType.isa<Torch::BaseTensorType>()) { if (isa<Torch::BaseTensorType>(torchType)) {
loopConditionIterArgs.push_back(torchArg); loopConditionIterArgs.push_back(torchArg);
continue; continue;
} }
@ -262,11 +262,11 @@ public:
// If the target type is non-torch type, then use TypeConverter to convert // If the target type is non-torch type, then use TypeConverter to convert
// the type of the source. // the type of the source.
if (targetType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(targetType)) {
targetType = Torch::FloatType::get(op->getContext()); targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion( torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to}); rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(targetType)) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth(); unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1) if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext()); targetType = Torch::BoolType::get(op->getContext());

View File

@ -42,11 +42,11 @@ static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) { Value val) {
Type ty = getElementTypeOrSelf(val.getType()); Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute { auto getAttr = [&]() -> Attribute {
if (ty.isa<mlir::IntegerType>()) if (isa<mlir::IntegerType>(ty))
return b.getIntegerAttr(ty, constant); return b.getIntegerAttr(ty, constant);
if (ty.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(ty))
return b.getFloatAttr(ty, constant); return b.getFloatAttr(ty, constant);
if (auto complexTy = ty.dyn_cast<mlir::ComplexType>()) if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return complex::NumberAttr::get(complexTy, constant, 0); return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type"); llvm_unreachable("unhandled element type");
}; };
@ -105,17 +105,17 @@ bool skipMultiplyAlpha(Value alphaValue) {
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType, static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementType); auto constType = RankedTensorType::get({}, elementType);
if (elementType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementType)) {
auto constAttr = SplatElementsAttr::get( auto constAttr = SplatElementsAttr::get(
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*negative=*/false)); /*negative=*/false));
return rewriter return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr) .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(elementType)) {
auto integerType = elementType.cast<mlir::IntegerType>(); auto integerType = cast<mlir::IntegerType>(elementType);
DenseElementsAttr constAttr; DenseElementsAttr constAttr;
if (integerType.isUnsigned()) { if (integerType.isUnsigned()) {
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
@ -134,17 +134,17 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType, static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementType); auto constType = RankedTensorType::get({}, elementType);
if (elementType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementType)) {
auto constAttr = SplatElementsAttr::get( auto constAttr = SplatElementsAttr::get(
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*negative=*/true)); /*negative=*/true));
return rewriter return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr) .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(elementType)) {
auto integerType = elementType.cast<mlir::IntegerType>(); auto integerType = cast<mlir::IntegerType>(elementType);
DenseElementsAttr constAttr; DenseElementsAttr constAttr;
if (integerType.isUnsigned()) { if (integerType.isUnsigned()) {
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
@ -446,7 +446,7 @@ public:
op, "only support constant str rounding mode"); op, "only support constant str rounding mode");
// if trunc and int, do nothing // if trunc and int, do nothing
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) { if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
// "trunc" - rounds the results of the division towards zero. Equivalent // "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division. // to C-style integer division.
auto sign = rewriter.create<stablehlo::SignOp>(loc, result); auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
@ -457,7 +457,7 @@ public:
if (roundingMode == "floor") { if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to // "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator) // floor division in Python (the // operator)
if (outElemTy.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(outElemTy))
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult(); result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
else if (!outElemTy.isUnsignedInteger()) { else if (!outElemTy.isUnsignedInteger()) {
TensorType defaultIntToFloatType = TensorType defaultIntToFloatType =
@ -518,10 +518,10 @@ public:
chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr; chlo::ComparisonDirectionAttr compareDirectionAttr;
if (lhsElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(lhsElemTy)) {
compareTypeAttr = chlo::ComparisonTypeAttr::get( compareTypeAttr = chlo::ComparisonTypeAttr::get(
op->getContext(), chlo::ComparisonType::FLOAT); op->getContext(), chlo::ComparisonType::FLOAT);
} else if (lhsElemTy.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(lhsElemTy)) {
compareTypeAttr = chlo::ComparisonTypeAttr::get( compareTypeAttr = chlo::ComparisonTypeAttr::get(
op->getContext(), chlo::ComparisonType::SIGNED); op->getContext(), chlo::ComparisonType::SIGNED);
} }
@ -985,14 +985,14 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = lhs.getType().cast<RankedTensorType>();
auto lhsElemTy = lhsTy.getElementType(); auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(lhsElemTy)) {
return op->emitError("only float tensor in relu op is supported"); return op->emitError("only float tensor in relu op is supported");
} }
Value zeroTensor; Value zeroTensor;
zeroTensor = getConstantLike( zeroTensor = getConstantLike(
rewriter, op->getLoc(), rewriter, op->getLoc(),
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
false), false),
lhs); lhs);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor); rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
@ -1160,7 +1160,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
rewriter.getI64IntegerAttr(feature_index)); rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(), output = hlo::promoteType(rewriter, op.getLoc(),
batchNormTrainingResult.getResult(0), batchNormTrainingResult.getResult(0),
outputTy.cast<TensorType>()); cast<TensorType>(outputTy));
} else { } else {
auto batchNormTrainingResult = auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>( rewriter.create<stablehlo::BatchNormTrainingOp>(
@ -1204,7 +1204,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
runningVar, rewriter.getF32FloatAttr(eps), runningVar, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index)); rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(), bnResult, output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
outputTy.cast<TensorType>()); cast<TensorType>(outputTy));
} else { } else {
output = rewriter.create<stablehlo::BatchNormInferenceOp>( output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
->convertType(op.getType()) ->convertType(op.getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
auto dtype = outType.getElementType(); auto dtype = outType.getElementType();
if (!dtype.isa<mlir::IntegerType>() && !dtype.isa<mlir::FloatType>()) { if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: only int or float dtype supported"); op, "unimplemented: only int or float dtype supported");
} }
@ -1607,7 +1607,7 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>( auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements)); loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType()); auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = outTy.cast<RankedTensorType>().getElementType(); auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from = Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy); hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to = Value to =

View File

@ -34,14 +34,14 @@ static Value createInitialValueForGatherScatterOp(Operation *op,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto elementTy = constType.getElementType(); auto elementTy = constType.getElementType();
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) { if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});

View File

@ -37,14 +37,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
// Avg pooling // Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) { AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
@ -55,14 +55,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
// Max pooling // Max pooling
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) { if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf( constType,
elementTy.cast<mlir::FloatType>().getFloatSemantics(), {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,

View File

@ -37,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy); auto constType = RankedTensorType::get({}, elementTy);
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp, if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
AtenLinalgVectorNormOp>(op)) { AtenLinalgVectorNormOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
@ -54,14 +54,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
} }
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) { if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf( constType,
elementTy.cast<mlir::FloatType>().getFloatSemantics(), {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
@ -72,14 +72,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
} }
if (isa<AtenMinOp>(op)) { if (isa<AtenMinOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf( constType,
elementTy.cast<mlir::FloatType>().getFloatSemantics(), {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
@ -234,7 +234,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported"); "only floating-point or integer datatype legalization supported");
} }
// Currently, (u)int8 dtype is not supported! // Currently, (u)int8 dtype is not supported!
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -305,7 +305,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
"Only floating-point or integer datatype legalization supported"); "Only floating-point or integer datatype legalization supported");
} }
// Currently, (u)int8 dtype is not supported // Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -319,7 +319,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
->convertType(op.getResult(1).getType()) ->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>(); .template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType(); Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<mlir::IntegerType>()) { if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result"); return op.emitError("Aten.max.dim needs integer-like result");
} }
@ -404,7 +404,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported"); "only floating-point or integer datatype legalization supported");
} }
// Currently, (u)int8 dtype is not supported // Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -466,7 +466,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported"); "only floating-point or integer datatype legalization supported");
} }
// Currently, (u)int8 dtype is not supported // Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -529,7 +529,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported"); "only floating-point or integer datatype legalization supported");
} }
// Currently, (u)int8 dtype is not supported // Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -603,7 +603,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
// Currently, (u)int8 dtype is not supported // Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() && if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -715,7 +715,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
} }
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
auto inputElemType = inputType.getElementType(); auto inputElemType = inputType.getElementType();
if (!inputElemType.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inputElemType)) {
return op.emitError( return op.emitError(
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
} }
@ -830,7 +830,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
auto outElemType = outType.getElementType(); auto outElemType = outType.getElementType();
if (!outElemType.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(outElemType)) {
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
} }
@ -912,7 +912,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
op->getLoc(), blockArgumentTy, op->getLoc(), blockArgumentTy,
DenseElementsAttr::get( DenseElementsAttr::get(
blockArgumentTy, blockArgumentTy,
APFloat(outElemType.cast<mlir::FloatType>().getFloatSemantics(), 1))); APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>( auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
op->getLoc(), blockArgumentTy, constantOne, ord); op->getLoc(), blockArgumentTy, constantOne, ord);
auto output = rewriter.create<chlo::BroadcastPowOp>( auto output = rewriter.create<chlo::BroadcastPowOp>(

View File

@ -144,12 +144,12 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant"); "Unable to extract the scalar constant");
if (dtype.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(dtype)) {
tosaTensor = tosa::getConstTensor<float>(rewriter, op, tosaTensor = tosa::getConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue), (isFloat ? doubleValue : intValue),
dshape, dtype) dshape, dtype)
.value(); .value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) { } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth(); auto w = intType.getWidth();
if (w != 1 && w != 32 && w != 64) if (w != 1 && w != 32 && w != 64)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
@ -261,7 +261,7 @@ public:
} }
Type rhsAlphaMulElemType; Type rhsAlphaMulElemType;
if (outElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(outElemTy)) {
rhsAlphaMulElemType = outElemTy; rhsAlphaMulElemType = outElemTy;
} else { } else {
// if output type is 64, input type should also be 32 // if output type is 64, input type should also be 32
@ -355,7 +355,7 @@ public:
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() || std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() || std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>(); std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
if (lhsElemTy.isa<mlir::FloatType>() && isBitwiseOp) { if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"For bitwise operators, only integer " "For bitwise operators, only integer "
"datatype legalization is supported"); "datatype legalization is supported");
@ -442,8 +442,7 @@ public:
rhsTensor = rhsType ? rhs : rhsAsTensor; rhsTensor = rhsType ? rhs : rhsAsTensor;
} }
if (outElemTy.isa<mlir::FloatType>() || if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
outElemTy.isa<mlir::IntegerType>()) {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
@ -1454,7 +1453,7 @@ public:
SmallVector<int64_t> matmulOutputShape( SmallVector<int64_t> matmulOutputShape(
{matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
Type outputElemTy; Type outputElemTy;
if (lhsElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(lhsElemTy)) {
outputElemTy = lhsElemTy; outputElemTy = lhsElemTy;
} else { // qint8 emits i32 matmul output } else { // qint8 emits i32 matmul output
outputElemTy = rewriter.getIntegerType(32); outputElemTy = rewriter.getIntegerType(32);
@ -1898,7 +1897,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
// define a 48-bit int. // define a 48-bit int.
if (inputElemTy.isa<quant::QuantizedType>()) { if (isa<quant::QuantizedType>(inputElemTy)) {
SmallVector<int32_t> zeroVec(weightShape[0], 0); SmallVector<int32_t> zeroVec(weightShape[0], 0);
bias = tosa::getConstTensor<int32_t>( bias = tosa::getConstTensor<int32_t>(
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])}) rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
@ -1915,7 +1914,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
op, "Bias provided but not a ranked tensor"); op, "Bias provided but not a ranked tensor");
} }
auto biasElemTy = auto biasElemTy =
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type(); isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter.getI32Type();
int64_t groups; int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
@ -2098,7 +2097,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.getResult(); .getResult();
Value rescaledResult = transposedOutput; Value rescaledResult = transposedOutput;
if (inputElemTy.isa<quant::QuantizedType>()) { if (isa<quant::QuantizedType>(inputElemTy)) {
rescaledResult = tosa::buildRescaleOpConvOutput( rescaledResult = tosa::buildRescaleOpConvOutput(
rewriter, op, transposedOutput, inputTy, weightTy, outputTy); rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
} }
@ -2230,7 +2229,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
if (toBcastType.getRank() > 1) if (toBcastType.getRank() > 1)
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1"); return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
RankedTensorType outTensorType = outType.cast<RankedTensorType>(); RankedTensorType outTensorType = cast<RankedTensorType>(outType);
SmallVector<int64_t> newShape = { SmallVector<int64_t> newShape = {
makeShapeTorchCompatible(toBcastType.getShape())[0]}; makeShapeTorchCompatible(toBcastType.getShape())[0]};
for (auto i = 2; i < outTensorType.getRank(); ++i) for (auto i = 2; i < outTensorType.getRank(); ++i)
@ -2677,7 +2676,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
op, "Only floating-point or integer datatype legalization supported"); op, "Only floating-point or integer datatype legalization supported");
// Integer types with width > 32 are not supported // Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>(); auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
if (selfIntType && selfIntType.getWidth() > 32) { if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported"); op, "Integer types with width greater than 32 are not supported");
@ -2956,7 +2955,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
op, "Only tensor types are currently supported"); op, "Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType(); auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(selfElemTy)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported"); op, "Only floating-point datatype legalization supported");
} }
@ -2993,7 +2992,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
op, "Only tensor types are currently supported"); op, "Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType(); auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(selfElemTy)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported"); op, "Only floating-point datatype legalization supported");
} }
@ -3057,7 +3056,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
} }
// Integer types with width > 32 are not supported // Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>(); auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
if (selfIntType && selfIntType.getWidth() > 32) { if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported"); op, "Integer types with width greater than 32 are not supported");
@ -4235,7 +4234,7 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
"Unable to extract the scalar constant"); "Unable to extract the scalar constant");
auto outElemTy = resultType.getElementType(); auto outElemTy = resultType.getElementType();
if (outElemTy.isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(outElemTy)) {
rewriter.replaceOpWithNewOp<tosa::ConstOp>( rewriter.replaceOpWithNewOp<tosa::ConstOp>(
op, resultType, DenseElementsAttr::get(resultType, {intValue})); op, resultType, DenseElementsAttr::get(resultType, {intValue}));
} else if (outElemTy.isF64()) { } else if (outElemTy.isF64()) {
@ -4383,7 +4382,7 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
auto divTensor = self; auto divTensor = self;
// tosa::DivOp only supports int // tosa::DivOp only supports int
if (outElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>( auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor); op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>( divTensor = rewriter.create<tosa::MulOp>(

View File

@ -119,7 +119,7 @@ tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Value lhs, Value rhs) { Value lhs, Value rhs) {
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType(); auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType(); auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
(void)rewriter.notifyMatchFailure(op, (void)rewriter.notifyMatchFailure(op,
"tosa.div only supports integer type"); "tosa.div only supports integer type");
} }
@ -213,7 +213,7 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op, std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
Type outType, Value paramsValue, Type outType, Value paramsValue,
Value indicesValue) { Value indicesValue) {
auto resultType = outType.dyn_cast<ShapedType>(); auto resultType = dyn_cast<ShapedType>(outType);
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>(); auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>(); auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
@ -419,7 +419,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
Operation *op, Type outType, Operation *op, Type outType,
Value paramsValue, Value indicesValue, Value paramsValue, Value indicesValue,
Value fillValues) { Value fillValues) {
auto resultType = outType.dyn_cast<ShapedType>(); auto resultType = dyn_cast<ShapedType>(outType);
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>(); auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>(); auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>(); auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
@ -981,7 +981,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt; return std::nullopt;
Type elemType = output_type.getElementType(); Type elemType = output_type.getElementType();
if (!elemType.isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(elemType)) {
op->emitOpError("Only floating-point datatype legalization supported for " op->emitOpError("Only floating-point datatype legalization supported for "
"AtenLinalgVectorNorm op"); "AtenLinalgVectorNorm op");
return std::nullopt; return std::nullopt;

View File

@ -154,7 +154,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
// Create a zero constant tensor of the desired type and shape. // Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter, std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type) { Operation *op, Type type) {
RankedTensorType resultType = type.dyn_cast<RankedTensorType>(); RankedTensorType resultType = dyn_cast<RankedTensorType>(type);
if (!resultType) { if (!resultType) {
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type"); (void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
@ -167,7 +167,7 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Attribute zeroAttr = rewriter.getZeroAttr(zeroType); Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType, return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
zeroAttr.cast<ElementsAttr>()) cast<ElementsAttr>(zeroAttr))
.getResult(); .getResult();
} }
@ -312,7 +312,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) { Value src, Type destType, Value &result) {
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType(); Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
Type destElemTy = destType.dyn_cast<TensorType>().getElementType(); Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
if (failed(checkValidityOfCast(srcElemTy, destElemTy))) if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -392,7 +392,7 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
// Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
// FP16 is supported, the accumulator type can be selected based on trade-off // FP16 is supported, the accumulator type can be selected based on trade-off
// between performance and accuracy. Set to FP32 by default. // between performance and accuracy. Set to FP32 by default.
accType = inputETy.isa<FloatType>() accType = isa<FloatType>(inputETy)
? mlir::TypeAttr::get(rewriter.getF32Type()) ? mlir::TypeAttr::get(rewriter.getF32Type())
: mlir::TypeAttr::get(rewriter.getIntegerType(32)); : mlir::TypeAttr::get(rewriter.getIntegerType(32));

View File

@ -27,9 +27,9 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
// TODO: Remove this check but use a separate verification pass to verify the // TODO: Remove this check but use a separate verification pass to verify the
// invariants expected by later passes. // invariants expected by later passes.
auto isValidLinalgType = [](Type type) { auto isValidLinalgType = [](Type type) {
if (type.isa<NonValueTensorType>()) if (isa<NonValueTensorType>(type))
return false; return false;
auto tensor = type.dyn_cast<ValueTensorType>(); auto tensor = dyn_cast<ValueTensorType>(type);
return !tensor || return !tensor ||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>(); tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
}; };
@ -43,8 +43,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) { LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
Type type = v.getType(); Type type = v.getType();
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() || if (isa<OptionalType>(type) || isa<Torch::NoneType>(type) ||
type.isa<mlir::NoneType>()) isa<mlir::NoneType>(type))
return rewriter.notifyMatchFailure(op, "unimplemented None type arg"); return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
return success(); return success();
} }
@ -104,7 +104,7 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
Type lhsType = lhsDim.getType(); Type lhsType = lhsDim.getType();
Type rhsType = rhsDim.getType(); Type rhsType = rhsDim.getType();
auto checkIntOrIndex = [](Type type) { auto checkIntOrIndex = [](Type type) {
assert((type.isa<IntegerType>() || type.isa<IndexType>()) && assert((isa<IntegerType>(type) || isa<IndexType>(type)) &&
"must be either integer or index type"); "must be either integer or index type");
}; };
checkIntOrIndex(lhsType); checkIntOrIndex(lhsType);
@ -198,13 +198,13 @@ Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
// Creates a constant of type `elemType` with value `val`. // Creates a constant of type `elemType` with value `val`.
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) { Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
TypedAttr attr = {}; TypedAttr attr = {};
if (elemType.isa<mlir::FloatType>()) if (isa<mlir::FloatType>(elemType))
attr = b.getFloatAttr(elemType, val); attr = b.getFloatAttr(elemType, val);
if (elemType.isa<mlir::IndexType>()) if (isa<mlir::IndexType>(elemType))
attr = b.getIndexAttr(val); attr = b.getIndexAttr(val);
if (elemType.isa<mlir::IntegerType>()) if (isa<mlir::IntegerType>(elemType))
attr = b.getIntegerAttr( attr = b.getIntegerAttr(elemType,
elemType, APInt(elemType.cast<IntegerType>().getWidth(), val)); APInt(cast<IntegerType>(elemType).getWidth(), val));
if (!attr) if (!attr)
return nullptr; return nullptr;
return b.create<arith::ConstantOp>(loc, elemType, attr); return b.create<arith::ConstantOp>(loc, elemType, attr);
@ -264,7 +264,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return scalar; return scalar;
auto isByteOrChar = [](Type type) { auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) { if (auto integerTy = dyn_cast<mlir::IntegerType>(type)) {
return integerTy.getWidth() == 8; return integerTy.getWidth() == 8;
} }
return false; return false;
@ -303,10 +303,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
if (dtype.isSignlessInteger(1)) { if (dtype.isSignlessInteger(1)) {
Type scalarType = scalar.getType(); Type scalarType = scalar.getType();
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType)); Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
if (scalarType.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(scalarType)) {
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
cstZero); cstZero);
} else if (scalarType.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(scalarType)) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar, return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
cstZero); cstZero);
} else { } else {
@ -317,14 +317,14 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
} }
} }
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) { if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) { if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth()) if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, dtype, scalar); return b.create<arith::TruncFOp>(loc, dtype, scalar);
// Only scalarFloat width < dtypeFloat width can reach here. // Only scalarFloat width < dtypeFloat width can reach here.
return b.create<arith::ExtFOp>(loc, dtype, scalar); return b.create<arith::ExtFOp>(loc, dtype, scalar);
} }
assert(scalarType.isa<mlir::IntegerType>()); assert(isa<mlir::IntegerType>(scalarType));
if (scalarType.isSignlessInteger(1) || if (scalarType.isSignlessInteger(1) ||
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger())) (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
return b.create<arith::UIToFPOp>(loc, dtype, scalar); return b.create<arith::UIToFPOp>(loc, dtype, scalar);
@ -333,11 +333,11 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::SIToFPOp>(loc, dtype, scalar); return b.create<arith::SIToFPOp>(loc, dtype, scalar);
} }
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) { if (auto dtypeInteger = dyn_cast<mlir::IntegerType>(dtype)) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType))
return b.create<arith::FPToSIOp>(loc, dtype, scalar); return b.create<arith::FPToSIOp>(loc, dtype, scalar);
assert(scalarType.isa<mlir::IntegerType>()); assert(isa<mlir::IntegerType>(scalarType));
auto scalarInteger = scalarType.cast<mlir::IntegerType>(); auto scalarInteger = cast<mlir::IntegerType>(scalarType);
if (scalarInteger.getWidth() > dtypeInteger.getWidth()) if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, dtype, scalar); return b.create<arith::TruncIOp>(loc, dtype, scalar);
if (scalarType.isSignlessInteger(1) || if (scalarType.isSignlessInteger(1) ||

View File

@ -49,7 +49,7 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
size_t resultIndex = en.index(); size_t resultIndex = en.index();
Type resultType = en.value(); Type resultType = en.value();
auto tensorType = resultType.dyn_cast<RankedTensorType>(); auto tensorType = dyn_cast<RankedTensorType>(resultType);
if (tensorType == nullptr) { if (tensorType == nullptr) {
tmtensorOp.emitOpError() tmtensorOp.emitOpError()
<< "tensor to buffer conversion expects ranked tensor results"; << "tensor to buffer conversion expects ranked tensor results";

View File

@ -100,10 +100,12 @@ void TorchDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
>(); >();
addTypes< addTypes<
#define GET_TYPEDEF_LIST #define GET_TYPEDEF_LIST
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
>(); >();
addInterfaces<TorchInlinerInterface>(); addInterfaces<TorchInlinerInterface>();
} }
@ -144,35 +146,34 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
Operation *TorchDialect::materializeConstant(OpBuilder &builder, Operation *TorchDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type, Attribute value, Type type,
Location loc) { Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>()) if (auto integerType = dyn_cast<Torch::IntType>(type))
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>()); return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
if (auto floatType = type.dyn_cast<Torch::FloatType>()) if (auto floatType = dyn_cast<Torch::FloatType>(type))
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>()); return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
if (auto numberType = type.dyn_cast<Torch::NumberType>()) { if (auto numberType = dyn_cast<Torch::NumberType>(type)) {
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) { if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
return builder.create<Torch::ConstantNumberOp>(loc, floatValue); return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) { } else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
return builder.create<Torch::ConstantNumberOp>(loc, intValue); return builder.create<Torch::ConstantNumberOp>(loc, intValue);
} }
} }
if (type.isa<Torch::BoolType>()) { if (isa<Torch::BoolType>(type)) {
return builder.create<Torch::ConstantBoolOp>(loc, return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
value.cast<IntegerAttr>());
} }
if (type.isa<Torch::NoneType>()) if (isa<Torch::NoneType>(type))
return builder.create<ConstantNoneOp>(loc); return builder.create<ConstantNoneOp>(loc);
if (auto stringAttr = value.dyn_cast<StringAttr>()) if (auto stringAttr = dyn_cast<StringAttr>(value))
return builder.create<ConstantStrOp>(loc, stringAttr); return builder.create<ConstantStrOp>(loc, stringAttr);
if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) { if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) {
// Only !torch.vtensor can be constant folded. !torch.tensor has // Only !torch.vtensor can be constant folded. !torch.tensor has
// non-trivial aliasing semantics which prevent deduplicating it. // non-trivial aliasing semantics which prevent deduplicating it.
assert(type.isa<ValueTensorType>() && "should be a vtensor type!"); assert(isa<ValueTensorType>(type) && "should be a vtensor type!");
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr); return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
} }

View File

@ -41,9 +41,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
return value; return value;
// If the type is a tensor, then adjust the static information. // If the type is a tensor, then adjust the static information.
if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) || if ((isa<ValueTensorType>(type) && isa<ValueTensorType>(desiredType)) ||
(type.isa<NonValueTensorType>() && (isa<NonValueTensorType>(type) && isa<NonValueTensorType>(desiredType))) {
desiredType.isa<NonValueTensorType>())) {
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(), Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
desiredType, value); desiredType, value);
return adjusted; return adjusted;
@ -90,7 +89,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
// then we do the copy by going to a value tensor and back. // then we do the copy by going to a value tensor and back.
if (tensor.getType().isa<NonValueTensorType>()) if (tensor.getType().isa<NonValueTensorType>())
tensor = builder.create<CopyToValueTensorOp>(loc, tensor); tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
if (newType.isa<NonValueTensorType>()) if (isa<NonValueTensorType>(newType))
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor); tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
return tensor; return tensor;
@ -132,11 +131,11 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
static Value getScalarIntValue(Value input, Location loc, static Value getScalarIntValue(Value input, Location loc,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto inputType = input.getType(); auto inputType = input.getType();
if (inputType.isa<Torch::IntType>()) { if (isa<Torch::IntType>(inputType)) {
return input; return input;
} }
auto inputTensorType = inputType.dyn_cast<BaseTensorType>(); auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
if (!inputTensorType) if (!inputTensorType)
return nullptr; return nullptr;
@ -166,11 +165,11 @@ static Value getScalarIntValue(Value input, Location loc,
static Value getScalarFloatValue(Value input, Location loc, static Value getScalarFloatValue(Value input, Location loc,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto inputType = input.getType(); auto inputType = input.getType();
if (inputType.isa<Torch::FloatType>()) { if (isa<Torch::FloatType>(inputType)) {
return input; return input;
} }
auto inputTensorType = inputType.dyn_cast<BaseTensorType>(); auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
if (!inputTensorType) if (!inputTensorType)
return nullptr; return nullptr;
@ -273,7 +272,7 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
LogicalResult PrimListConstructOp::verify() { LogicalResult PrimListConstructOp::verify() {
auto resultType = getResult().getType(); auto resultType = getResult().getType();
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType(); auto resultElementType = dyn_cast<ListType>(resultType).getContainedType();
auto matchResultElementType = [&](Type type) { auto matchResultElementType = [&](Type type) {
return isValidSubtype(type, resultElementType); return isValidSubtype(type, resultElementType);
}; };
@ -606,7 +605,7 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
Type rhsType = rhs.getType(); Type rhsType = rhs.getType();
// If either type is a NoneType, make it be the lhsType. // If either type is a NoneType, make it be the lhsType.
if (rhsType.isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(rhsType)) {
std::swap(lhsType, rhsType); std::swap(lhsType, rhsType);
std::swap(lhs, rhs); std::swap(lhs, rhs);
} }
@ -615,14 +614,14 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
// If both types are the singleton `!torch.none` type, then we don't even need // If both types are the singleton `!torch.none` type, then we don't even need
// to look at the values. // to look at the values.
if (lhsType.isa<Torch::NoneType>() && rhsType.isa<Torch::NoneType>()) if (isa<Torch::NoneType>(lhsType) && isa<Torch::NoneType>(rhsType))
return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue); return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);
// If neither type is a subtype of the other, then the result is false. // If neither type is a subtype of the other, then the result is false.
// TODO: Implement and use subtype infra for this. // TODO: Implement and use subtype infra for this.
// For now, check a specific case. // For now, check a specific case.
// If the rhs is not OptionalType, then we know it cannot be None. // If the rhs is not OptionalType, then we know it cannot be None.
if (lhsType.isa<Torch::NoneType>() && !rhsType.isa<Torch::OptionalType>()) { if (isa<Torch::NoneType>(lhsType) && !isa<Torch::OptionalType>(rhsType)) {
return IntegerAttr::get(IntegerType::get(op->getContext(), 1), return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
!equalIsTrue); !equalIsTrue);
} }
@ -640,9 +639,9 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
auto step = adaptor.getStep(); auto step = adaptor.getStep();
if (!lo || !hi || !step) if (!lo || !hi || !step)
return nullptr; return nullptr;
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue(); auto loInt = dyn_cast_or_null<IntegerAttr>(lo).getValue();
auto hiInt = hi.dyn_cast_or_null<IntegerAttr>().getValue(); auto hiInt = dyn_cast_or_null<IntegerAttr>(hi).getValue();
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue(); auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
// TODO: Implement folding for negative steps. // TODO: Implement folding for negative steps.
if (stepInt.isNegative()) if (stepInt.isNegative())
return nullptr; return nullptr;
@ -650,7 +649,7 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
// r[i] = lo + step*i such that i >= 0 and r[i] < hi // r[i] = lo + step*i such that i >= 0 and r[i] < hi
// So maximize `i` such that lo + step * i < hi // So maximize `i` such that lo + step * i < hi
// ==> i == ceildiv(hi - lo, step) // ==> i == ceildiv(hi - lo, step)
return IntegerAttr::get(lo.cast<TypedAttr>().getType(), return IntegerAttr::get(cast<TypedAttr>(lo).getType(),
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
APInt::Rounding::UP)); APInt::Rounding::UP));
} }
@ -665,10 +664,10 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
auto step = adaptor.getStep(); auto step = adaptor.getStep();
if (!index || !start || !step) if (!index || !start || !step)
return nullptr; return nullptr;
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue(); auto indexInt = dyn_cast_or_null<IntegerAttr>(index).getValue();
auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue(); auto startInt = dyn_cast_or_null<IntegerAttr>(start).getValue();
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue(); auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
return IntegerAttr::get(index.cast<TypedAttr>().getType(), return IntegerAttr::get(cast<TypedAttr>(index).getType(),
startInt + stepInt * indexInt); startInt + stepInt * indexInt);
} }
@ -2768,9 +2767,9 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
Value constValue; Value constValue;
Attribute value = op.getValueAttr(); Attribute value = op.getValueAttr();
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) { if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue); constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) { } else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue); constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
} else { } else {
return failure(); return failure();
@ -3192,9 +3191,9 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryFloatOperatorFn f) { BinaryFloatOperatorFn f) {
double lhs, rhs; double lhs, rhs;
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool { auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) { if (auto intLhs = dyn_cast_or_null<IntegerAttr>(attr)) {
value = static_cast<double>(intLhs.getValue().getSExtValue()); value = static_cast<double>(intLhs.getValue().getSExtValue());
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) { } else if (auto floatLhs = dyn_cast_or_null<FloatAttr>(attr)) {
value = floatLhs.getValue().convertToDouble(); value = floatLhs.getValue().convertToDouble();
} else { } else {
return false; return false;
@ -3945,7 +3944,7 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
} }
Type resultType = getResult().getType(); Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>(); BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() || if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) { !resultTensorType.hasSizes()) {
return nullptr; return nullptr;
@ -3966,11 +3965,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
auto elementType = shapedty.getElementType(); auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) { if (isa<IntegerType>(elementType)) {
Attribute attribute = IntegerAttr::get(elementType, 1); Attribute attribute = IntegerAttr::get(elementType, 1);
return DenseElementsAttr::get(shapedty, attribute); return DenseElementsAttr::get(shapedty, attribute);
} }
if (elementType.isa<FloatType>()) { if (isa<FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 1.0); Attribute attribute = FloatAttr::get(elementType, 1.0);
return DenseElementsAttr::get(shapedty, attribute); return DenseElementsAttr::get(shapedty, attribute);
} }
@ -3984,7 +3983,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
} }
Type resultType = getResult().getType(); Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>(); BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() || if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) { !resultTensorType.hasSizes()) {
return nullptr; return nullptr;
@ -4006,11 +4005,11 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
} }
auto elementType = shapedty.getElementType(); auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) { if (isa<IntegerType>(elementType)) {
Attribute attribute = IntegerAttr::get(elementType, 0); Attribute attribute = IntegerAttr::get(elementType, 0);
return DenseElementsAttr::get(shapedty, attribute); return DenseElementsAttr::get(shapedty, attribute);
} }
if (elementType.isa<FloatType>()) { if (isa<FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 0.0); Attribute attribute = FloatAttr::get(elementType, 0.0);
return DenseElementsAttr::get(shapedty, attribute); return DenseElementsAttr::get(shapedty, attribute);
} }
@ -4025,7 +4024,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
} }
Type resultType = getResult().getType(); Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>(); BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() || if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) { !resultTensorType.hasSizes()) {
return nullptr; return nullptr;
@ -4043,14 +4042,14 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
auto elementType = shapedty.getElementType(); auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) { if (isa<IntegerType>(elementType)) {
int64_t value = 0; int64_t value = 0;
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) { if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
Attribute attribute = IntegerAttr::get(elementType, value); Attribute attribute = IntegerAttr::get(elementType, value);
return DenseElementsAttr::get(shapedty, attribute); return DenseElementsAttr::get(shapedty, attribute);
} }
} }
if (elementType.isa<FloatType>()) { if (isa<FloatType>(elementType)) {
double value = 0.0; double value = 0.0;
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) { if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
Attribute attribute = FloatAttr::get(elementType, value); Attribute attribute = FloatAttr::get(elementType, value);
@ -4631,15 +4630,14 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator()); auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
for (Attribute symName : initialize.getSlotSymNames()) { for (Attribute symName : initialize.getSlotSymNames()) {
auto wasInserted = initializedGlobalSlots auto wasInserted = initializedGlobalSlots
.insert(symName.cast<FlatSymbolRefAttr>().getAttr()) .insert(cast<FlatSymbolRefAttr>(symName).getAttr())
.second; .second;
if (!wasInserted) if (!wasInserted)
return initialize.emitError("duplicate initialization of global slot: ") return initialize.emitError("duplicate initialization of global slot: ")
<< symName; << symName;
} }
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) { auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
return lhs.cast<StringAttr>().getValue() < return cast<StringAttr>(lhs).getValue() < cast<StringAttr>(rhs).getValue();
rhs.cast<StringAttr>().getValue();
}; };
auto known = llvm::to_vector(knownGlobalSlots); auto known = llvm::to_vector(knownGlobalSlots);
llvm::sort(known, lessThanByStringValue); llvm::sort(known, lessThanByStringValue);
@ -4652,7 +4650,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
InFlightDiagnostic diag = initialize.emitOpError( InFlightDiagnostic diag = initialize.emitOpError(
"must have one initializer for each global slot in the module"); "must have one initializer for each global slot in the module");
for (auto knownGlobalSlot : known) { for (auto knownGlobalSlot : known) {
auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>()); auto symName = FlatSymbolRefAttr::get(cast<StringAttr>(knownGlobalSlot));
if (!initializedGlobalSlots.count(knownGlobalSlot)) { if (!initializedGlobalSlots.count(knownGlobalSlot)) {
diag.attachNote( diag.attachNote(
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc()) symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
@ -4663,7 +4661,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
if (!knownGlobalSlots.count(initializedGlobalSlot)) { if (!knownGlobalSlots.count(initializedGlobalSlot)) {
diag.attachNote().append( diag.attachNote().append(
"unexpected global slot initializer for non-existent global slot ", "unexpected global slot initializer for non-existent global slot ",
FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>())); FlatSymbolRefAttr::get(cast<StringAttr>(initializedGlobalSlot)));
} }
} }
return diag; return diag;

View File

@ -29,7 +29,7 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
// For a UnionType to be a subtype, all of its contained types must be // For a UnionType to be a subtype, all of its contained types must be
// subtypes. // subtypes.
if (auto unionType = subtype.dyn_cast<UnionType>()) { if (auto unionType = dyn_cast<UnionType>(subtype)) {
for (auto containedType : unionType.getContainedTypes()) { for (auto containedType : unionType.getContainedTypes()) {
if (!isValidSubtype(containedType, type)) if (!isValidSubtype(containedType, type))
return false; return false;
@ -37,17 +37,17 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true; return true;
} }
if (auto any = type.dyn_cast<AnyType>()) if (auto any = dyn_cast<AnyType>(type))
return true; return true;
if (auto number = type.dyn_cast<NumberType>()) if (auto number = dyn_cast<NumberType>(type))
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>(); return isa<IntType>(subtype) || isa<Torch::FloatType>(subtype);
if (auto optional = type.dyn_cast<OptionalType>()) if (auto optional = dyn_cast<OptionalType>(type))
return isValidSubtype(subtype, optional.getContainedType()) || return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>(); isa<Torch::NoneType>(subtype);
if (auto unionType = type.dyn_cast<UnionType>()) { if (auto unionType = dyn_cast<UnionType>(type)) {
for (auto containedType : unionType.getContainedTypes()) { for (auto containedType : unionType.getContainedTypes()) {
if (isValidSubtype(subtype, containedType)) if (isValidSubtype(subtype, containedType))
return true; return true;
@ -55,10 +55,10 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return false; return false;
} }
if (auto tuple = type.dyn_cast<Torch::TupleType>()) { if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
if (!subtype.isa<Torch::TupleType>()) if (!isa<Torch::TupleType>(subtype))
return false; return false;
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes(); auto subtypes = cast<Torch::TupleType>(subtype).getContainedTypes();
auto types = tuple.getContainedTypes(); auto types = tuple.getContainedTypes();
if (subtypes.size() != types.size()) if (subtypes.size() != types.size())
return false; return false;
@ -69,14 +69,14 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true; return true;
} }
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>(); auto subtypeTensorType = dyn_cast<BaseTensorType>(subtype);
auto typeTensorType = type.dyn_cast<BaseTensorType>(); auto typeTensorType = dyn_cast<BaseTensorType>(type);
if (subtypeTensorType && typeTensorType) { if (subtypeTensorType && typeTensorType) {
// Check that both tensors have the same `BaseTensorType` subtype. // Check that both tensors have the same `BaseTensorType` subtype.
// TODO: This is not subtyping according to PEP 483. See description // TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType. // of NonValueTensorType.
if (subtypeTensorType.isa<ValueTensorType>() != if (isa<ValueTensorType>(subtypeTensorType) !=
typeTensorType.isa<ValueTensorType>()) isa<ValueTensorType>(typeTensorType))
return false; return false;
// `type` must not have more static information than `subtype`, and `type` // `type` must not have more static information than `subtype`, and `type`
@ -181,23 +181,23 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
static bool isValidTorchDtype(Type dtype) { static bool isValidTorchDtype(Type dtype) {
// For complex types, get the underlying element type // For complex types, get the underlying element type
if (dtype.isa<ComplexType>()) { if (isa<ComplexType>(dtype)) {
dtype = dtype.cast<ComplexType>().getElementType(); dtype = cast<ComplexType>(dtype).getElementType();
} }
// Torch quantized types. // Torch quantized types.
if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>()) if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>(dtype))
return true; return true;
// Builtin floating point types. // Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>()) if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
return true; return true;
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType, if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>()) Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
return true; return true;
if (dtype.isa<Torch::StringType>()) if (isa<Torch::StringType>(dtype))
return true; return true;
// Builtin integer types. // Builtin integer types.
if (IntegerType type = dtype.dyn_cast<IntegerType>()) { if (IntegerType type = dyn_cast<IntegerType>(dtype)) {
if (type.isSignless() && type.getWidth() == 1) if (type.isSignless() && type.getWidth() == 1)
return true; return true;
if (type.isSigned()) { if (type.isSigned()) {
@ -273,7 +273,7 @@ verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
} }
} }
} }
if (!optionalSparsity.isa<sparse_tensor::SparseTensorEncodingAttr>()) { if (!isa<sparse_tensor::SparseTensorEncodingAttr>(optionalSparsity)) {
emitError() << "invalid sparsity encoding attribute"; emitError() << "invalid sparsity encoding attribute";
return failure(); return failure();
} }
@ -441,12 +441,12 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
} }
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) { if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
return dtype; return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) { } else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
return IntegerType::get(context, integerType.getWidth(), return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless); IntegerType::Signless);
} else if (dtype.isa<mlir::ComplexType>()) { } else if (isa<mlir::ComplexType>(dtype)) {
return dtype; return dtype;
} }
@ -502,8 +502,8 @@ void ValueTensorType::print(AsmPrinter &printer) const {
} }
Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
assert(((lhs.isa<ValueTensorType>() && rhs.isa<ValueTensorType>()) || assert(((isa<ValueTensorType>(lhs) && isa<ValueTensorType>(rhs)) ||
(lhs.isa<NonValueTensorType>() && rhs.isa<NonValueTensorType>())) && (isa<NonValueTensorType>(lhs) && isa<NonValueTensorType>(rhs))) &&
"expected lhs and rhs to have same sense of value semantics"); "expected lhs and rhs to have same sense of value semantics");
// First, calculate the dtype. // First, calculate the dtype.
@ -566,21 +566,21 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
// linkage) and the predicates themselves can't be added/used in the // linkage) and the predicates themselves can't be added/used in the
// specification of the parameters of the Torch_DictType. // specification of the parameters of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) { static bool isAnyTorchDictKeyType(Type type) {
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() || return isa<Torch::AnyType>(type) || isa<Torch::IntType>(type) ||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() || isa<Torch::BoolType>(type) || isa<Torch::FloatType>(type) ||
type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>(); isa<Torch::StringType>(type) || isa<Torch::BaseTensorType>(type);
} }
static bool isAnyTorchType(Type type) { static bool isAnyTorchType(Type type) {
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) || return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() || isa<Torch::BaseTensorType>(type) || isa<Torch::AnyType>(type) ||
type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() || isa<Torch::BoolType>(type) || isa<Torch::DictType>(type) ||
type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() || isa<Torch::DeviceType>(type) || isa<Torch::GeneratorType>(type) ||
type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() || isa<Torch::ListType>(type) || isa<Torch::LinearParamsType>(type) ||
type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() || isa<Torch::NumberType>(type) || isa<Torch::NnModuleType>(type) ||
type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() || isa<Torch::NoneType>(type) || isa<Torch::OptionalType>(type) ||
type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() || isa<Torch::StringType>(type) || isa<Torch::TupleType>(type) ||
type.isa<Torch::UnionType>(); isa<Torch::UnionType>(type);
} }
LogicalResult LogicalResult

View File

@ -53,7 +53,7 @@ public:
auto typeBoundAttr = auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent); func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type(); Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
if (!bound.isa<ValueTensorType>()) if (!isa<ValueTensorType>(bound))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
func, "unimplemented: preserving aliasing for non-value-semantic " func, "unimplemented: preserving aliasing for non-value-semantic "
"type bounds"); "type bounds");
@ -72,10 +72,10 @@ public:
SmallVector<Type> newResultTypes; SmallVector<Type> newResultTypes;
for (auto type : func.getFunctionType().getResults()) { for (auto type : func.getFunctionType().getResults()) {
if (auto none = type.dyn_cast<Torch::NoneType>()) { if (auto none = dyn_cast<Torch::NoneType>(type)) {
continue; continue;
} }
if (auto tuple = type.dyn_cast<Torch::TupleType>()) { if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
llvm::append_range(newResultTypes, tuple.getContainedTypes()); llvm::append_range(newResultTypes, tuple.getContainedTypes());
continue; continue;
} }
@ -133,12 +133,12 @@ public:
int newOpResultIdx = 0; int newOpResultIdx = 0;
SmallVector<Value> newResults; SmallVector<Value> newResults;
for (auto type : call.getResultTypes()) { for (auto type : call.getResultTypes()) {
if (type.isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(type)) {
newResults.push_back( newResults.push_back(
rewriter.create<ConstantNoneOp>(call.getLoc(), type)); rewriter.create<ConstantNoneOp>(call.getLoc(), type));
continue; continue;
} }
if (type.isa<Torch::TupleType>()) { if (isa<Torch::TupleType>(type)) {
newResults.push_back(rewriter.create<PrimTupleConstructOp>( newResults.push_back(rewriter.create<PrimTupleConstructOp>(
call.getLoc(), type, newCall.getResults())); call.getLoc(), type, newCall.getResults()));
continue; continue;

View File

@ -1386,7 +1386,7 @@ static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
unNormalizedExp, sum); unNormalizedExp, sum);
if (resultType != accumulatorType) if (resultType != accumulatorType)
result = convertTensorToDtype(rewriter, loc, result, result = convertTensorToDtype(rewriter, loc, result,
resultType.cast<BaseTensorType>().getDtype()); cast<BaseTensorType>(resultType).getDtype());
return result; return result;
} }
@ -1405,7 +1405,7 @@ public:
op, "expected result type to have a dtype"); op, "expected result type to have a dtype");
} }
Type resultTensorDtype = resultTensorType.getDtype(); Type resultTensorDtype = resultTensorType.getDtype();
if (!resultTensorDtype.isa<mlir::FloatType>()) if (!isa<mlir::FloatType>(resultTensorDtype))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only support floating-point type"); "Only support floating-point type");
@ -1980,7 +1980,7 @@ public:
} }
Type dtype = resType.getDtype(); Type dtype = resType.getDtype();
if (dtype.isa<mlir::ComplexType>()) { if (isa<mlir::ComplexType>(dtype)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "lowering of aten.linalg_cross for complex inputs dtype is " op, "lowering of aten.linalg_cross for complex inputs dtype is "
"currently unimplemented"); "currently unimplemented");
@ -2015,7 +2015,7 @@ public:
Value none = rewriter.create<ConstantNoneOp>(loc); Value none = rewriter.create<ConstantNoneOp>(loc);
// idx = torch.arange(3) // idx = torch.arange(3)
auto outType = opType.dyn_cast<BaseTensorType>(); auto outType = dyn_cast<BaseTensorType>(opType);
auto arangeType = outType.getWithSizesAndDtype( auto arangeType = outType.getWithSizesAndDtype(
llvm::ArrayRef<int64_t>(3), llvm::ArrayRef<int64_t>(3),
IntegerType::get(op.getContext(), 64, IntegerType::Signed)); IntegerType::get(op.getContext(), 64, IntegerType::Signed));
@ -5848,7 +5848,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Value keepDim = op.getKeepdim(); Value keepDim = op.getKeepdim();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>(); BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType(); Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>(); BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
if (!outputTensorType.hasDtype()) { if (!outputTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"expected result type to have a dtype"); "expected result type to have a dtype");
@ -5893,7 +5893,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Type meanDimResultType = inputTensorTy; Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++) for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType( meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(), rewriter, op, cast<BaseTensorType>(meanDimResultType),
dimListElements[i], dimListElements[i],
/*keepDim=*/true); /*keepDim=*/true);
@ -6189,7 +6189,7 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Type resultType = op.getType(); Type resultType = op.getType();
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>(); BaseTensorType resultTensorType = cast<BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) { if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype"); op, "expected result type to have a dtype");

View File

@ -207,7 +207,7 @@ public:
return failure(); return failure();
Type resultETy = resultTy.getDtype(); Type resultETy = resultTy.getDtype();
if (!resultETy.isa<mlir::FloatType>()) if (!isa<mlir::FloatType>(resultETy))
return failure(); return failure();
Value lhsScale; Value lhsScale;

View File

@ -183,13 +183,13 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
} }
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = point.dyn_cast<Value>()) { if (Value value = dyn_cast<Value>(point)) {
bool isSafe = isValueSafeTransferFunction(value); bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value); auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe)); propagateIfChanged(state, state->setSafe(isSafe));
// Handle GlobalSlotGetOp's. // Handle GlobalSlotGetOp's.
if (auto opResult = value.dyn_cast<OpResult>()) { if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet = if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) { dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>( auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
@ -205,7 +205,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
return success(); return success();
} }
if (auto *genericProgramPoint = point.dyn_cast<GenericProgramPoint *>()) { if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
if (auto *flatSymbolRefPoint = if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) { dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) { if (initializeGlobalSlotsOp) {
@ -396,7 +396,7 @@ class InlineGlobalSlotsPass
// This could be left to SymbolDCE but it's not hard to do here. // This could be left to SymbolDCE but it's not hard to do here.
for (FlatSymbolRefAttr symName : for (FlatSymbolRefAttr symName :
llvm::map_range(safeToInline, [](Attribute attr) { llvm::map_range(safeToInline, [](Attribute attr) {
return attr.cast<FlatSymbolRefAttr>(); return cast<FlatSymbolRefAttr>(attr);
})) { })) {
auto globalSlot = auto globalSlot =
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue()); symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());

View File

@ -46,14 +46,14 @@ static LogicalResult checkType(Operation *op, Type type,
// can statically pattern match and eliminate from the program. // can statically pattern match and eliminate from the program.
// For example, a tensor operand might be optional, and the backend // For example, a tensor operand might be optional, and the backend
// will pattern-match statically whether it is passed as a tensor or None. // will pattern-match statically whether it is passed as a tensor or None.
if (type.isa<Torch::NoneType, Torch::StringType>()) if (isa<Torch::NoneType, Torch::StringType>(type))
return success(); return success();
// We blanket prohibit non-value-semantic tensors. // We blanket prohibit non-value-semantic tensors.
// All of our backends are currently based on value-semantic tensors, so // All of our backends are currently based on value-semantic tensors, so
// we consider it our responsibility to lower all non-value-semantic tensors // we consider it our responsibility to lower all non-value-semantic tensors
// to value-semantic tensors. // to value-semantic tensors.
if (type.isa<NonValueTensorType>()) { if (isa<NonValueTensorType>(type)) {
if (actuallyEmitDiagnostics) { if (actuallyEmitDiagnostics) {
return op return op
->emitError("unsupported by backend contract: non-value tensor type") ->emitError("unsupported by backend contract: non-value tensor type")
@ -84,7 +84,7 @@ static LogicalResult checkType(Operation *op, Type type,
// have an sufficiently rich system for representing PyTorch type promotion // have an sufficiently rich system for representing PyTorch type promotion
// rules. So we consider it our responsibility to ensure that all dtypes are // rules. So we consider it our responsibility to ensure that all dtypes are
// statically known. // statically known.
if (auto tensorType = type.dyn_cast<ValueTensorType>()) { if (auto tensorType = dyn_cast<ValueTensorType>(type)) {
if (!tensorType.hasSizes()) { if (!tensorType.hasSizes()) {
if (actuallyEmitDiagnostics) { if (actuallyEmitDiagnostics) {
return op return op
@ -115,7 +115,7 @@ static LogicalResult checkType(Operation *op, Type type,
// Optional types are also in the category of types which we don't expect // Optional types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched // backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary. // in many cases that are practically necessary.
if (auto optionalType = type.dyn_cast<OptionalType>()) { if (auto optionalType = dyn_cast<OptionalType>(type)) {
// TODO: Be stricter about tensor types. // TODO: Be stricter about tensor types.
// See comment below for ListType. // See comment below for ListType.
if (optionalType.getContainedType().isa<ValueTensorType>()) if (optionalType.getContainedType().isa<ValueTensorType>())
@ -127,7 +127,7 @@ static LogicalResult checkType(Operation *op, Type type,
// backends to dynamically compute with, but they can be pattern matched // backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary. For example, the // in many cases that are practically necessary. For example, the
// strides of a convolution op are represented as a list. // strides of a convolution op are represented as a list.
if (auto listType = type.dyn_cast<ListType>()) { if (auto listType = dyn_cast<ListType>(type)) {
// TODO: Be stricter about tensor types. // TODO: Be stricter about tensor types.
// For the moment, there are cases (such as for torch.cat) where we end // For the moment, there are cases (such as for torch.cat) where we end
// up with `!torch.list<vtensor>` which doesn't have shape or dtype in // up with `!torch.list<vtensor>` which doesn't have shape or dtype in
@ -141,7 +141,7 @@ static LogicalResult checkType(Operation *op, Type type,
// Tuple types are also in the category of types which we don't expect // Tuple types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched // backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary. // in many cases that are practically necessary.
if (auto tupleType = type.dyn_cast<Torch::TupleType>()) { if (auto tupleType = dyn_cast<Torch::TupleType>(type)) {
for (auto containedType : tupleType.getContainedTypes()) { for (auto containedType : tupleType.getContainedTypes()) {
if (failed(checkType(op, containedType, actuallyEmitDiagnostics))) if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
return failure(); return failure();

View File

@ -140,7 +140,7 @@ public:
auto returnOp = ops.returnOp.value(); auto returnOp = ops.returnOp.value();
for (auto operand : llvm::enumerate(returnOp->getOperands())) { for (auto operand : llvm::enumerate(returnOp->getOperands())) {
auto type = operand.value().getType(); auto type = operand.value().getType();
if (!type.isa<NonValueTensorType>()) if (!isa<NonValueTensorType>(type))
continue; continue;
originalReturnTypes[operand.index()] = type; originalReturnTypes[operand.index()] = type;
} }

View File

@ -38,15 +38,15 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
} }
static Type getContainerOrTensorTypeWithValueSemantics(Type type) { static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
if (auto optionalType = type.dyn_cast<OptionalType>()) { if (auto optionalType = dyn_cast<OptionalType>(type)) {
Type newContainedType = getContainerOrTensorTypeWithValueSemantics( Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
optionalType.getContainedType()); optionalType.getContainedType());
return OptionalType::get(newContainedType); return OptionalType::get(newContainedType);
} else if (auto listType = type.dyn_cast<ListType>()) { } else if (auto listType = dyn_cast<ListType>(type)) {
Type newContainedType = Type newContainedType =
getContainerOrTensorTypeWithValueSemantics(listType.getContainedType()); getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
return ListType::get(newContainedType); return ListType::get(newContainedType);
} else if (auto tensorType = type.dyn_cast<NonValueTensorType>()) { } else if (auto tensorType = dyn_cast<NonValueTensorType>(type)) {
return tensorType.getWithValueSemantics(); return tensorType.getWithValueSemantics();
} else { } else {
return nullptr; return nullptr;
@ -92,10 +92,10 @@ public:
SmallVector<Value> newOperands; SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) { for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType(); Type operandType = opOperand.get().getType();
if (operandType.isa<NonValueTensorType>()) { if (isa<NonValueTensorType>(operandType)) {
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(), opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
opOperand.get())); opOperand.get()));
} else if (auto listType = operandType.dyn_cast<ListType>()) { } else if (auto listType = dyn_cast<ListType>(operandType)) {
if (!(listType.getContainedType().isa<NonValueTensorType>() || if (!(listType.getContainedType().isa<NonValueTensorType>() ||
listType.getContainedType().isa<OptionalType>())) listType.getContainedType().isa<OptionalType>()))
continue; continue;
@ -144,7 +144,7 @@ public:
} }
opOperand.set(rewriter.create<PrimListConstructOp>( opOperand.set(rewriter.create<PrimListConstructOp>(
op->getLoc(), newListType, newListElements)); op->getLoc(), newListType, newListElements));
} else if (auto optionalType = operandType.dyn_cast<OptionalType>()) { } else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
// TODO: A more general way to handle the optional type is to // TODO: A more general way to handle the optional type is to
// introduce a `copy.to_optional_vtensor` op. // introduce a `copy.to_optional_vtensor` op.
if (!optionalType.getContainedType().isa<NonValueTensorType>()) if (!optionalType.getContainedType().isa<NonValueTensorType>())
@ -450,7 +450,7 @@ struct ReduceOpVariantsPass
auto hasValueSemantics = [](Type t) { auto hasValueSemantics = [](Type t) {
// TODO: Make this an allowlist based on a closed torch dialect // TODO: Make this an allowlist based on a closed torch dialect
// type system. // type system.
if (auto tensorType = t.dyn_cast<NonValueTensorType>()) { if (auto tensorType = dyn_cast<NonValueTensorType>(t)) {
return false; return false;
} }
return true; return true;

View File

@ -170,7 +170,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
if (operandType == desiredType) if (operandType == desiredType)
return operand; return operand;
if (desiredType.isa<Torch::AnyType>()) { if (isa<Torch::AnyType>(desiredType)) {
// Generator's are currently passed as Any because TorchScript cannot // Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments. // compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need // Ignoring that hack, this is a correct handling of Any type should we need
@ -180,8 +180,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// The type `!torch.number` can be an `int`, `float`, or `complex`. // The type `!torch.number` can be an `int`, `float`, or `complex`.
// TODO: Add a new type `Torch::ComplexType` to handle the complex case. // TODO: Add a new type `Torch::ComplexType` to handle the complex case.
if (desiredType.isa<Torch::NumberType>() && if (isa<Torch::NumberType>(desiredType) &&
operandType.isa<Torch::IntType, Torch::FloatType>()) { isa<Torch::IntType, Torch::FloatType>(operandType)) {
return b.create<DerefineOp>(loc, desiredType, operand).getResult(); return b.create<DerefineOp>(loc, desiredType, operand).getResult();
} }
@ -189,7 +189,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// `Scalar` inputs. At compile time, such inputs will usually be // `Scalar` inputs. At compile time, such inputs will usually be
// resolved to an `int`, `float`, or `None` so we need to derefine // resolved to an `int`, `float`, or `None` so we need to derefine
// to match the library function signature. // to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) { if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(); .isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
@ -200,8 +200,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// Operands with type `!torch.none` correspond to library function inputs with // Operands with type `!torch.none` correspond to library function inputs with
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the // types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
// type is derefined to match the expected type of the library function. // type is derefined to match the expected type of the library function.
if (operandType.isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(operandType)) {
assert(!desiredType.isa<Torch::NoneType>() && assert(!isa<Torch::NoneType>(desiredType) &&
"Don't expect library functions to have NoneType parameters"); "Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult(); return b.create<DerefineOp>(loc, desiredType, operand).getResult();
} }
@ -211,8 +211,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// dtype of input scalars. However, this also means we sometimes have to // dtype of input scalars. However, this also means we sometimes have to
// manually turn `Scalar`s into `float`s when inserting the shape functions // manually turn `Scalar`s into `float`s when inserting the shape functions
// into the IR. // into the IR.
if (operandType.isa<Torch::NumberType>() && if (isa<Torch::NumberType>(operandType) &&
desiredType.isa<Torch::FloatType>()) { isa<Torch::FloatType>(desiredType)) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult(); return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
} }
@ -224,8 +224,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// type). // type).
// A case where this happens is `!torch.optional<vtensor>` -> // A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`. // `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) { if (auto operandOptionalType = dyn_cast<Torch::OptionalType>(operandType)) {
if (desiredType.isa<Torch::OptionalType>()) { if (isa<Torch::OptionalType>(desiredType)) {
// if optional is None: // if optional is None:
// return derefine(None) // return derefine(None)
// else: // else:
@ -258,7 +258,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// If the desired type is OptionalType, then recursively adjust the operand to // If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example, // the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<list<int>>>`. // `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) { if (auto desiredOptionalType = dyn_cast<Torch::OptionalType>(desiredType)) {
FailureOr<Value> adjusted = adjustFunctionArg( FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, operand, desiredOptionalType.getContainedType(), b, loc, operand, desiredOptionalType.getContainedType(),
baseTransformation); baseTransformation);
@ -267,7 +267,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult(); return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
} }
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) { if (auto desiredListType = dyn_cast<Torch::ListType>(desiredType)) {
// Pseudocode: // Pseudocode:
// //
// operand = ... // operand = ...
@ -311,7 +311,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// The library functions use `float` where the operator // The library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for // signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation). // explanation).
if (desiredType.isa<Torch::FloatType>() && if (isa<Torch::FloatType>(desiredType) &&
operand.getType().isa<Torch::IntType>()) { operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult(); return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
} }

View File

@ -29,7 +29,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype) // Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value { Type desiredType) -> Value {
if (desiredType.isa<Torch::TupleType>() && if (isa<Torch::TupleType>(desiredType) &&
operand.getType().isa<Torch::BaseTensorType>()) { operand.getType().isa<Torch::BaseTensorType>()) {
Type intType = Torch::IntType::get(b.getContext()); Type intType = Torch::IntType::get(b.getContext());
Type sizeListType = Torch::ListType::get(intType); Type sizeListType = Torch::ListType::get(intType);

View File

@ -38,7 +38,7 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
Type desiredType) -> Value { Type desiredType) -> Value {
// The shape library functions have tensor operands replaced with // The shape library functions have tensor operands replaced with
// `!torch.list<int>` types for the shape. Get the sizes. // `!torch.list<int>` types for the shape. Get the sizes.
auto desiredListType = desiredType.dyn_cast<Torch::ListType>(); auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
if (!desiredListType) if (!desiredListType)
return operand; return operand;
if (operand.getType().isa<Torch::BaseTensorType>() && if (operand.getType().isa<Torch::BaseTensorType>() &&

View File

@ -262,13 +262,13 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
originalResultType.template dyn_cast<BaseTensorType>()) { originalResultType.template dyn_cast<BaseTensorType>()) {
// If we didn't get any new information, there is nothing left for us to do. // If we didn't get any new information, there is nothing left for us to do.
updatedType = meetTensorTypes(originalBaseTensorType, updatedType = meetTensorTypes(originalBaseTensorType,
newResultType.cast<BaseTensorType>()); cast<BaseTensorType>(newResultType));
if (!updatedType || updatedType == originalBaseTensorType) if (!updatedType || updatedType == originalBaseTensorType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
calculateOp, "New type information does not refine old type"); calculateOp, "New type information does not refine old type");
} else if (auto originalResultType = } else if (auto originalResultType =
result.getType().template dyn_cast<Torch::NumberType>()) { result.getType().template dyn_cast<Torch::NumberType>()) {
if (!newResultType.isa<Torch::FloatType, Torch::IntType>()) { if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
calculateOp, calculateOp,
"Refinement of `NumberType` must be a `FloatType` or `IntType`"); "Refinement of `NumberType` must be a `FloatType` or `IntType`");
@ -291,10 +291,10 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
} }
if (!originalTypedValue) { if (!originalTypedValue) {
rewriter.setInsertionPointAfter(calculateOp); rewriter.setInsertionPointAfter(calculateOp);
if (originalResultType.isa<BaseTensorType>()) { if (isa<BaseTensorType>(originalResultType)) {
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>( originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
loc, originalResultType, result); loc, originalResultType, result);
} else if (originalResultType.isa<Torch::NumberType>()) { } else if (isa<Torch::NumberType>(originalResultType)) {
originalTypedValue = originalTypedValue =
rewriter.create<DerefineOp>(loc, originalResultType, result); rewriter.create<DerefineOp>(loc, originalResultType, result);
} else { } else {
@ -314,14 +314,14 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
OpOperand &use = yieldValues->getOpOperand(resultNum); OpOperand &use = yieldValues->getOpOperand(resultNum);
Value def = use.get(); Value def = use.get();
Value newYieldedValue; Value newYieldedValue;
if (def.isa<OpResult>() && if (isa<OpResult>(def) &&
def.cast<OpResult>() cast<OpResult>(def)
.getDefiningOp() .getDefiningOp()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) { ->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
newYieldedValue = def; newYieldedValue = def;
} else { } else {
rewriter.setInsertionPoint(yieldValues); rewriter.setInsertionPoint(yieldValues);
if (updatedType.isa<BaseTensorType>()) { if (isa<BaseTensorType>(updatedType)) {
newYieldedValue = newYieldedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def); rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
} else { } else {

View File

@ -53,8 +53,9 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
op, "Failed to convert `dtypeScalarType` to a builtin type"); op, "Failed to convert `dtypeScalarType` to a builtin type");
} }
impliedTypeFromDtype = impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype( cast<BaseTensorType>(originalResultType)
originalResultType.getOptionalSizes(), *builtinType); .getWithSizesAndDtype(originalResultType.getOptionalSizes(),
*builtinType);
} else { } else {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to " "Unimplemented: Expected result type to "
@ -179,7 +180,7 @@ public:
} }
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType()); Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType = auto impliedTypeFromInputType =
originalResultType.cast<BaseTensorType>() cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(originalResultType.getOptionalSizes(), .getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType) inputType)
.cast<BaseTensorType>(); .cast<BaseTensorType>();

View File

@ -98,7 +98,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
auto originalResultType = result.getType().cast<BaseTensorType>(); auto originalResultType = result.getType().cast<BaseTensorType>();
auto impliedTypesFromShape = auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>() cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(ArrayRef(sizes), .getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype()) originalResultType.getOptionalDtype())
.cast<BaseTensorType>(); .cast<BaseTensorType>();

View File

@ -70,8 +70,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::QInt8; return torch_upstream::ScalarType::QInt8;
if (type.isa<QInt32Type>()) if (type.isa<QInt32Type>())
return torch_upstream::ScalarType::QInt32; return torch_upstream::ScalarType::QInt32;
if (type.isa<ComplexType>()) { if (isa<ComplexType>(type)) {
mlir::Type complexElemType = type.cast<ComplexType>().getElementType(); mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
if (complexElemType.isF16()) if (complexElemType.isF16())
return torch_upstream::ScalarType::ComplexHalf; return torch_upstream::ScalarType::ComplexHalf;
if (complexElemType.isF32()) if (complexElemType.isF32())
@ -84,9 +84,9 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
Type Torch::getTypeForTorchType( Type Torch::getTypeForTorchType(
MLIRContext *context, Type type, MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness) { mlir::IntegerType::SignednessSemantics signedness) {
if (type.isa<Torch::IntType>()) if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, signedness); return IntegerType::get(context, 64, signedness);
if (type.isa<Torch::FloatType>()) if (isa<Torch::FloatType>(type))
return Float64Type::get(context); return Float64Type::get(context);
llvm::report_fatal_error("unhandled type for getTypeForTorchType"); llvm::report_fatal_error("unhandled type for getTypeForTorchType");
} }
@ -150,14 +150,14 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
Type Torch::getDefaultDtypeForTorchScalar(Type type) { Type Torch::getDefaultDtypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext(); MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>()) { if (isa<Torch::FloatType>(type)) {
// For now, use float32 which is the initial default dtype returned by // For now, use float32 which is the initial default dtype returned by
// `torch.get_default_dtype`. // `torch.get_default_dtype`.
return Float32Type::get(context); return Float32Type::get(context);
} }
if (type.isa<Torch::IntType>()) if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, IntegerType::Signed); return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>()) if (isa<Torch::BoolType>(type))
return IntegerType::get(context, 1); return IntegerType::get(context, 1);
llvm_unreachable( llvm_unreachable(
"getDefaultDtypeForTorchScalar called on an unsupported type"); "getDefaultDtypeForTorchScalar called on an unsupported type");
@ -165,11 +165,11 @@ Type Torch::getDefaultDtypeForTorchScalar(Type type) {
Type Torch::getBuiltInTypeForTorchScalar(Type type) { Type Torch::getBuiltInTypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext(); MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>()) if (isa<Torch::FloatType>(type))
return Float64Type::get(context); return Float64Type::get(context);
if (type.isa<Torch::IntType>()) if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, IntegerType::Signed); return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>()) if (isa<Torch::BoolType>(type))
return IntegerType::get(context, 1); return IntegerType::get(context, 1);
llvm_unreachable( llvm_unreachable(
"getBuiltInTypeForTorchScalar called on an unsupported type"); "getBuiltInTypeForTorchScalar called on an unsupported type");

View File

@ -62,15 +62,14 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
Attribute value, Attribute value,
Type type, Type type,
Location loc) { Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>()) if (auto integerType = dyn_cast<Torch::IntType>(type))
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>()); return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
if (auto floatType = type.dyn_cast<Torch::FloatType>()) if (auto floatType = dyn_cast<Torch::FloatType>(type))
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>()); return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
if (type.isa<Torch::BoolType>()) { if (isa<Torch::BoolType>(type)) {
return builder.create<Torch::ConstantBoolOp>(loc, return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
value.cast<IntegerAttr>());
} }
return arith::ConstantOp::materialize(builder, value, type, loc); return arith::ConstantOp::materialize(builder, value, type, loc);

View File

@ -95,7 +95,7 @@ public:
// get outputs // get outputs
Type newResultType = getTypeConverter()->convertType(op.getType(0)); Type newResultType = getTypeConverter()->convertType(op.getType(0));
auto resultType = newResultType.cast<RankedTensorType>(); auto resultType = cast<RankedTensorType>(newResultType);
if (!resultType) { if (!resultType) {
return failure(); return failure();
} }

View File

@ -33,7 +33,7 @@ class VerifyStablehloBackendContractPass
converter.addConversion([](Type type) -> Type { converter.addConversion([](Type type) -> Type {
auto elemTy = type; auto elemTy = type;
if (isa<TensorType>(type)) if (isa<TensorType>(type))
elemTy = type.cast<TensorType>().getElementType(); elemTy = cast<TensorType>(type).getElementType();
if (BaseMemRefType::isValidElementType(elemTy)) if (BaseMemRefType::isValidElementType(elemTy))
return type; return type;
return nullptr; return nullptr;

View File

@ -54,11 +54,11 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static bool isArgMemRefTypeValid(Type type) { static bool isArgMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) { if (auto memRefType = dyn_cast<MemRefType>(type)) {
Type elemTy = memRefType.getElementType(); Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) { if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
return true; return true;
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) { } else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
if (integerTy.isSignlessInteger(64)) if (integerTy.isSignlessInteger(64))
return true; return true;
if (integerTy.isSignlessInteger(32)) if (integerTy.isSignlessInteger(32))
@ -69,7 +69,7 @@ static bool isArgMemRefTypeValid(Type type) {
return true; return true;
if (integerTy.isSignlessInteger(1)) if (integerTy.isSignlessInteger(1))
return true; return true;
} else if (auto complexTy = elemTy.dyn_cast<ComplexType>()) { } else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
return complexTy.getElementType().isa<Float32Type, Float64Type>(); return complexTy.getElementType().isa<Float32Type, Float64Type>();
} }
} }
@ -81,7 +81,7 @@ static void addEmitCInterfaceAttr(func::FuncOp func) {
} }
static Type getAbiTypeForMemRef(Type type) { static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0); return UnrankedMemRefType::get(cast<MemRefType>(type).getElementType(), 0);
} }
// Helper function to get the type string for one return value like i32, f64, // Helper function to get the type string for one return value like i32, f64,
@ -90,12 +90,12 @@ static Type getAbiTypeForMemRef(Type type) {
static std::string getTypeToken(Type type) { static std::string getTypeToken(Type type) {
if (type.isSignlessInteger()) if (type.isSignlessInteger())
return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
else if (type.isa<mlir::FloatType>()) else if (isa<mlir::FloatType>(type))
return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
else if (auto complexTy = type.dyn_cast<mlir::ComplexType>()) else if (auto complexTy = dyn_cast<mlir::ComplexType>(type))
return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth())) return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth()))
.str(); .str();
else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>()) else if (auto memRefType = dyn_cast<UnrankedMemRefType>(type))
return "mr" + getTypeToken(memRefType.getElementType()); return "mr" + getTypeToken(memRefType.getElementType());
llvm_unreachable( llvm_unreachable(
@ -171,7 +171,7 @@ static LogicalResult mungeFunction(
for (auto en : llvm::enumerate(types)) { for (auto en : llvm::enumerate(types)) {
Type retType = en.value(); Type retType = en.value();
Value retVal = op.getOperand(en.index()); Value retVal = op.getOperand(en.index());
if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) { if (auto memrefReturnType = dyn_cast<MemRefType>(retType)) {
auto elemType = memrefReturnType.getElementType(); auto elemType = memrefReturnType.getElementType();
retType = UnrankedMemRefType::get(elemType, 0); retType = UnrankedMemRefType::get(elemType, 0);
// Cast to unranked memref type before sending it as a function // Cast to unranked memref type before sending it as a function