mirror of https://github.com/llvm/torch-mlir
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)
We should prefer functional style as the method style is deprecated https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated (https://mlir.llvm.org/deprecation/)pull/3141/head
parent
308c45e61a
commit
d4a30b7e67
|
@ -178,7 +178,7 @@ struct OpBinder {
|
||||||
}
|
}
|
||||||
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||||
for (auto element : arrayAttr) {
|
for (auto element : arrayAttr) {
|
||||||
auto integerAttr = element.dyn_cast<IntegerAttr>();
|
auto integerAttr = dyn_cast<IntegerAttr>(element);
|
||||||
if (!integerAttr)
|
if (!integerAttr)
|
||||||
return failure();
|
return failure();
|
||||||
IntegerType t = cast<IntegerType>(integerAttr.getType());
|
IntegerType t = cast<IntegerType>(integerAttr.getType());
|
||||||
|
@ -200,7 +200,7 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||||
for (auto element : arrayAttr) {
|
for (auto element : arrayAttr) {
|
||||||
StringAttr stringAttr = element.dyn_cast<StringAttr>();
|
StringAttr stringAttr = dyn_cast<StringAttr>(element);
|
||||||
if (!stringAttr)
|
if (!stringAttr)
|
||||||
return failure();
|
return failure();
|
||||||
values.push_back(stringAttr.getValue().str());
|
values.push_back(stringAttr.getValue().str());
|
||||||
|
|
|
@ -94,7 +94,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
|
||||||
|
|
||||||
// Compute the knowledge based on the inferred type.
|
// Compute the knowledge based on the inferred type.
|
||||||
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
|
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
|
||||||
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
|
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
|
||||||
inferredKnowledge.hasRank = predictedShape.hasRank();
|
inferredKnowledge.hasRank = predictedShape.hasRank();
|
||||||
if (predictedShape.hasRank()) {
|
if (predictedShape.hasRank()) {
|
||||||
for (auto dim : predictedShape.getDims()) {
|
for (auto dim : predictedShape.getDims()) {
|
||||||
|
|
|
@ -1287,7 +1287,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.getLoc(), axisScalar, finalOffset);
|
binder.getLoc(), axisScalar, finalOffset);
|
||||||
|
|
||||||
Torch::BaseTensorType resultTensorType =
|
Torch::BaseTensorType resultTensorType =
|
||||||
resultType.cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(resultType);
|
||||||
if (!resultTensorType.hasDtype()) {
|
if (!resultTensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "expected result type to have a dtype");
|
binder.op, "expected result type to have a dtype");
|
||||||
|
@ -1899,7 +1899,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
|
|
||||||
// If its a dense resource attr we need to convert to a dense type:
|
// If its a dense resource attr we need to convert to a dense type:
|
||||||
if (DenseResourceElementsAttr rattr =
|
if (DenseResourceElementsAttr rattr =
|
||||||
attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
|
dyn_cast_or_null<DenseResourceElementsAttr>(attr)) {
|
||||||
// Bytes are stored in little endian order. Big endian support will
|
// Bytes are stored in little endian order. Big endian support will
|
||||||
// require swizzling.
|
// require swizzling.
|
||||||
if (!Endian::little) {
|
if (!Endian::little) {
|
||||||
|
@ -1916,7 +1916,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
|
|
||||||
Attribute splattr;
|
Attribute splattr;
|
||||||
if (isa<SplatElementsAttr>(attr)) {
|
if (isa<SplatElementsAttr>(attr)) {
|
||||||
auto denseAttr = attr.cast<DenseElementsAttr>();
|
auto denseAttr = cast<DenseElementsAttr>(attr);
|
||||||
splattr = denseAttr.getSplatValue<Attribute>();
|
splattr = denseAttr.getSplatValue<Attribute>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1366,7 +1366,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
// set the splitted axis to variable shape
|
// set the splitted axis to variable shape
|
||||||
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
||||||
for (auto result : binder.op->getResultTypes()) {
|
for (auto result : binder.op->getResultTypes()) {
|
||||||
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
|
||||||
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1437,7 +1437,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
||||||
for (auto result : binder.op->getResultTypes()) {
|
for (auto result : binder.op->getResultTypes()) {
|
||||||
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
|
||||||
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -272,9 +272,9 @@ public:
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
|
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
|
||||||
Value operandB =
|
Value operandB =
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
|
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
|
||||||
if (resultType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(resultType)) {
|
||||||
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
|
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
|
||||||
} else if (resultType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(resultType)) {
|
||||||
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
|
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
|
|
@ -1881,7 +1881,7 @@ public:
|
||||||
|
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||||
auto inputElementType = getElementTypeOrSelf(input.getType());
|
auto inputElementType = getElementTypeOrSelf(input.getType());
|
||||||
if (!inputElementType.isa<ComplexType>()) {
|
if (!isa<ComplexType>(inputElementType)) {
|
||||||
return op.emitError("only ComplexType is allowed as input type");
|
return op.emitError("only ComplexType is allowed as input type");
|
||||||
}
|
}
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
|
|
|
@ -131,7 +131,7 @@ public:
|
||||||
auto resultTy = op.getType().cast<ValueTensorType>();
|
auto resultTy = op.getType().cast<ValueTensorType>();
|
||||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
Type elementType = cast<TensorType>(newResultType).getElementType();
|
||||||
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
|
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
elementType = accumulatorDType;
|
elementType = accumulatorDType;
|
||||||
|
@ -201,7 +201,7 @@ public:
|
||||||
|
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
matmul = torch_to_linalg::convertTensorToElementType(
|
matmul = torch_to_linalg::convertTensorToElementType(
|
||||||
rewriter, loc, matmul, resultElementType);
|
rewriter, loc, matmul, resultElementType);
|
||||||
}
|
}
|
||||||
|
@ -307,7 +307,7 @@ public:
|
||||||
unsigned rhsRank = rhsType.getRank();
|
unsigned rhsRank = rhsType.getRank();
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
auto resultType = newResultType.cast<RankedTensorType>();
|
auto resultType = cast<RankedTensorType>(newResultType);
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
|
|
||||||
// The different cases of torch_matmul op is mentioned here:
|
// The different cases of torch_matmul op is mentioned here:
|
||||||
|
@ -600,9 +600,9 @@ public:
|
||||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
|
Type lhsElementType = cast<RankedTensorType>(lhsType).getElementType();
|
||||||
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
|
Type rhsElementType = cast<RankedTensorType>(rhsType).getElementType();
|
||||||
|
|
||||||
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -712,9 +712,9 @@ public:
|
||||||
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
|
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
|
||||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||||
|
|
||||||
if (!inputDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
|
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
|
||||||
!weightDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
|
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
|
||||||
!resultDTy.isa<mlir::FloatType, mlir::IntegerType>())
|
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
|
||||||
return op.emitError("unimplemented: non-fp not-int type");
|
return op.emitError("unimplemented: non-fp not-int type");
|
||||||
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
|
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
|
||||||
size_t numSpatialDims = inRank - 2;
|
size_t numSpatialDims = inRank - 2;
|
||||||
|
@ -790,9 +790,8 @@ public:
|
||||||
SmallVector<Value> outDims{inBatch, weightBatch};
|
SmallVector<Value> outDims{inBatch, weightBatch};
|
||||||
Value paddedInput;
|
Value paddedInput;
|
||||||
if (transposed) {
|
if (transposed) {
|
||||||
if (!inputDTy.isa<mlir::FloatType>() ||
|
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
|
||||||
!weightDTy.isa<mlir::FloatType>() ||
|
!isa<mlir::FloatType>(resultDTy))
|
||||||
!resultDTy.isa<mlir::FloatType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "transpose does not support non-fp type yet");
|
op, "transpose does not support non-fp type yet");
|
||||||
|
|
||||||
|
@ -927,10 +926,10 @@ public:
|
||||||
accumulatorDType);
|
accumulatorDType);
|
||||||
if (bias.getType().isa<Torch::NoneType>()) {
|
if (bias.getType().isa<Torch::NoneType>()) {
|
||||||
Value c0;
|
Value c0;
|
||||||
if (accumulatorDType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(accumulatorDType)) {
|
||||||
c0 = rewriter.create<arith::ConstantOp>(
|
c0 = rewriter.create<arith::ConstantOp>(
|
||||||
loc, FloatAttr::get(accumulatorDType, 0.0));
|
loc, FloatAttr::get(accumulatorDType, 0.0));
|
||||||
} else if (accumulatorDType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(accumulatorDType)) {
|
||||||
c0 = rewriter.create<arith::ConstantOp>(
|
c0 = rewriter.create<arith::ConstantOp>(
|
||||||
loc, IntegerAttr::get(accumulatorDType, 0));
|
loc, IntegerAttr::get(accumulatorDType, 0));
|
||||||
}
|
}
|
||||||
|
@ -1021,7 +1020,7 @@ public:
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
}
|
}
|
||||||
|
@ -1081,7 +1080,7 @@ public:
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
}
|
}
|
||||||
|
@ -1125,7 +1124,7 @@ public:
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
}
|
}
|
||||||
|
@ -1203,7 +1202,7 @@ public:
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
if (accumulatorDType != resultDTy) {
|
if (accumulatorDType != resultDTy) {
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
newResultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -154,7 +154,7 @@ static LogicalResult createPoolingOp(
|
||||||
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||||
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
|
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
|
||||||
return op->emitError("unimplemented: non-floating point type");
|
return op->emitError("unimplemented: non-floating point type");
|
||||||
|
|
||||||
Value initValue =
|
Value initValue =
|
||||||
|
@ -217,7 +217,7 @@ private:
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
Value initValue =
|
Value initValue =
|
||||||
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
|
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
|
||||||
|
@ -335,7 +335,7 @@ public:
|
||||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(
|
APFloat::getInf(
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||||
|
@ -416,7 +416,7 @@ public:
|
||||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
Value maxPool2d, paddedInput;
|
Value maxPool2d, paddedInput;
|
||||||
SmallVector<Value, 4> outTensorShape;
|
SmallVector<Value, 4> outTensorShape;
|
||||||
|
@ -555,7 +555,7 @@ public:
|
||||||
self.getType().cast<RankedTensorType>().getElementType();
|
self.getType().cast<RankedTensorType>().getElementType();
|
||||||
Type resultType = typeConverter->convertType(op.getType());
|
Type resultType = typeConverter->convertType(op.getType());
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
resultType.cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(resultType).getElementType();
|
||||||
|
|
||||||
bool ceilMode;
|
bool ceilMode;
|
||||||
SmallVector<Value, Dim> kernelSizeIntValues;
|
SmallVector<Value, Dim> kernelSizeIntValues;
|
||||||
|
@ -615,9 +615,9 @@ public:
|
||||||
/*iteratorTypes=*/iteratorTypesAvg,
|
/*iteratorTypes=*/iteratorTypesAvg,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Value avg;
|
Value avg;
|
||||||
if (resultElementType.isa<mlir::IntegerType>())
|
if (isa<mlir::IntegerType>(resultElementType))
|
||||||
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
||||||
else if (resultElementType.isa<mlir::FloatType>())
|
else if (isa<mlir::FloatType>(resultElementType))
|
||||||
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
||||||
b.create<linalg::YieldOp>(loc, avg);
|
b.create<linalg::YieldOp>(loc, avg);
|
||||||
})
|
})
|
||||||
|
@ -707,7 +707,7 @@ public:
|
||||||
Type auxTensorElementType = auxTensorType.getElementType();
|
Type auxTensorElementType = auxTensorType.getElementType();
|
||||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
|
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
|
||||||
smallestFPValueAttr);
|
smallestFPValueAttr);
|
||||||
|
|
|
@ -130,7 +130,7 @@ public:
|
||||||
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
|
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
|
||||||
Type elemTy = resultType.getElementType();
|
Type elemTy = resultType.getElementType();
|
||||||
|
|
||||||
if (!elemTy.isa<mlir::FloatType>())
|
if (!isa<mlir::FloatType>(elemTy))
|
||||||
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
||||||
|
|
||||||
if (!generator.getType().isa<Torch::NoneType>())
|
if (!generator.getType().isa<Torch::NoneType>())
|
||||||
|
|
|
@ -70,7 +70,7 @@ public:
|
||||||
input.getType().template cast<RankedTensorType>();
|
input.getType().template cast<RankedTensorType>();
|
||||||
Type idxElementType =
|
Type idxElementType =
|
||||||
getElementTypeOrSelf(typec->convertType(idxResultType));
|
getElementTypeOrSelf(typec->convertType(idxResultType));
|
||||||
if (!idxElementType.isa<IntegerType>())
|
if (!isa<IntegerType>(idxElementType))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, opName + " to linalg.* requires integer-like result type");
|
op, opName + " to linalg.* requires integer-like result type");
|
||||||
|
|
||||||
|
@ -89,8 +89,8 @@ public:
|
||||||
|
|
||||||
Type inElementType = inputType.getElementType();
|
Type inElementType = inputType.getElementType();
|
||||||
bool isUnsigned = false;
|
bool isUnsigned = false;
|
||||||
if (!inElementType.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inElementType)) {
|
||||||
if (inElementType.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(inElementType)) {
|
||||||
auto integerTy = op.getSelf()
|
auto integerTy = op.getSelf()
|
||||||
.getType()
|
.getType()
|
||||||
.template cast<BaseTensorType>()
|
.template cast<BaseTensorType>()
|
||||||
|
@ -121,22 +121,21 @@ public:
|
||||||
loc, getAsOpFoldResult(resultShape), inElementType);
|
loc, getAsOpFoldResult(resultShape), inElementType);
|
||||||
|
|
||||||
Value fillValue;
|
Value fillValue;
|
||||||
if (inElementType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(inElementType)) {
|
||||||
fillValue = rewriter.create<arith::ConstantOp>(
|
fillValue = rewriter.create<arith::ConstantOp>(
|
||||||
loc,
|
loc, rewriter.getFloatAttr(
|
||||||
rewriter.getFloatAttr(
|
|
||||||
inElementType,
|
inElementType,
|
||||||
APFloat::getInf(
|
APFloat::getInf(
|
||||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(inElementType).getFloatSemantics(),
|
||||||
/*Negative=*/isMax)));
|
/*Negative=*/isMax)));
|
||||||
} else if (!isUnsigned) {
|
} else if (!isUnsigned) {
|
||||||
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
|
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
|
||||||
auto init = isMax ? APSInt::getSignedMinValue(width)
|
auto init = isMax ? APSInt::getSignedMinValue(width)
|
||||||
: APSInt::getSignedMaxValue(width);
|
: APSInt::getSignedMaxValue(width);
|
||||||
fillValue = rewriter.create<arith::ConstantOp>(
|
fillValue = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(inElementType, init));
|
loc, rewriter.getIntegerAttr(inElementType, init));
|
||||||
} else if (isUnsigned) {
|
} else if (isUnsigned) {
|
||||||
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
|
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
|
||||||
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
|
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
|
||||||
fillValue = rewriter.create<arith::ConstantOp>(
|
fillValue = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(inElementType, init));
|
loc, rewriter.getIntegerAttr(inElementType, init));
|
||||||
|
@ -180,7 +179,7 @@ public:
|
||||||
rewriter.create<linalg::IndexOp>(loc, dim));
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||||
|
|
||||||
Value resultVal, predicate;
|
Value resultVal, predicate;
|
||||||
if (inElementType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(inElementType)) {
|
||||||
arith::CmpFPredicate predType;
|
arith::CmpFPredicate predType;
|
||||||
if (isMax) {
|
if (isMax) {
|
||||||
predType = arith::CmpFPredicate::OGT;
|
predType = arith::CmpFPredicate::OGT;
|
||||||
|
@ -300,21 +299,21 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||||
|
|
||||||
if (isa<AtenProdDimIntOp>(op)) {
|
if (isa<AtenProdDimIntOp>(op)) {
|
||||||
if (elementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(elementType))
|
||||||
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0));
|
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0));
|
||||||
else if (elementType.isa<mlir::IntegerType>())
|
else if (isa<mlir::IntegerType>(elementType))
|
||||||
return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1));
|
return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenMaxOp>(op)) {
|
if (isa<AtenMaxOp>(op)) {
|
||||||
if (elementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(elementType))
|
||||||
return b.create<arith::ConstantOp>(
|
return b.create<arith::ConstantOp>(
|
||||||
loc, b.getFloatAttr(
|
loc, b.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(
|
APFloat::getInf(
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/true)));
|
/*Negative=*/true)));
|
||||||
else if (elementType.isa<mlir::IntegerType>() &&
|
else if (isa<mlir::IntegerType>(elementType) &&
|
||||||
elementType.getIntOrFloatBitWidth() != 8)
|
elementType.getIntOrFloatBitWidth() != 8)
|
||||||
return b.create<arith::ConstantOp>(
|
return b.create<arith::ConstantOp>(
|
||||||
loc, b.getIntegerAttr(elementType,
|
loc, b.getIntegerAttr(elementType,
|
||||||
|
@ -323,14 +322,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenMinOp>(op)) {
|
if (isa<AtenMinOp>(op)) {
|
||||||
if (elementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(elementType))
|
||||||
return b.create<arith::ConstantOp>(
|
return b.create<arith::ConstantOp>(
|
||||||
loc, b.getFloatAttr(
|
loc, b.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(
|
APFloat::getInf(
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*Negative=*/false)));
|
/*Negative=*/false)));
|
||||||
else if (elementType.isa<mlir::IntegerType>() &&
|
else if (isa<mlir::IntegerType>(elementType) &&
|
||||||
elementType.getIntOrFloatBitWidth() != 8)
|
elementType.getIntOrFloatBitWidth() != 8)
|
||||||
return b.create<arith::ConstantOp>(
|
return b.create<arith::ConstantOp>(
|
||||||
loc, b.getIntegerAttr(elementType,
|
loc, b.getIntegerAttr(elementType,
|
||||||
|
@ -359,25 +358,25 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||||
Value self =
|
Value self =
|
||||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||||
Value result = payloadArgs[1];
|
Value result = payloadArgs[1];
|
||||||
if (resultElementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::AddFOp>(loc, self, result);
|
return b.create<arith::AddFOp>(loc, self, result);
|
||||||
else if (resultElementType.isa<mlir::IntegerType>())
|
else if (isa<mlir::IntegerType>(resultElementType))
|
||||||
return b.create<arith::AddIOp>(loc, self, result);
|
return b.create<arith::AddIOp>(loc, self, result);
|
||||||
} else if (isa<AtenProdDimIntOp>(op)) {
|
} else if (isa<AtenProdDimIntOp>(op)) {
|
||||||
Value self =
|
Value self =
|
||||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||||
Value result = payloadArgs[1];
|
Value result = payloadArgs[1];
|
||||||
if (resultElementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::MulFOp>(loc, self, result);
|
return b.create<arith::MulFOp>(loc, self, result);
|
||||||
else if (resultElementType.isa<mlir::IntegerType>())
|
else if (isa<mlir::IntegerType>(resultElementType))
|
||||||
return b.create<arith::MulIOp>(loc, self, result);
|
return b.create<arith::MulIOp>(loc, self, result);
|
||||||
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
|
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
|
||||||
Value self =
|
Value self =
|
||||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||||
Value result = payloadArgs[1];
|
Value result = payloadArgs[1];
|
||||||
if (resultElementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::MaximumFOp>(loc, self, result);
|
return b.create<arith::MaximumFOp>(loc, self, result);
|
||||||
else if (resultElementType.isa<mlir::IntegerType>()) {
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||||
IntegerType intType = max.getSelf()
|
IntegerType intType = max.getSelf()
|
||||||
.getType()
|
.getType()
|
||||||
.cast<BaseTensorType>()
|
.cast<BaseTensorType>()
|
||||||
|
@ -392,9 +391,9 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||||
Value self =
|
Value self =
|
||||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||||
Value result = payloadArgs[1];
|
Value result = payloadArgs[1];
|
||||||
if (resultElementType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::MinimumFOp>(loc, self, result);
|
return b.create<arith::MinimumFOp>(loc, self, result);
|
||||||
else if (resultElementType.isa<mlir::IntegerType>()) {
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||||
IntegerType intType = min.getSelf()
|
IntegerType intType = min.getSelf()
|
||||||
.getType()
|
.getType()
|
||||||
.cast<BaseTensorType>()
|
.cast<BaseTensorType>()
|
||||||
|
@ -626,10 +625,10 @@ private:
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
|
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
|
||||||
isa<AtenNormScalarOp>(op)) &&
|
isa<AtenNormScalarOp>(op)) &&
|
||||||
!elemType.isa<mlir::FloatType>())
|
!isa<mlir::FloatType>(elemType))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only float types are valid for vector norm ops");
|
op, "only float types are valid for vector norm ops");
|
||||||
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
|
if (isa<AtenAllDimOp>(op) && isa<mlir::IntegerType>(elemType) &&
|
||||||
elemType.getIntOrFloatBitWidth() == 8)
|
elemType.getIntOrFloatBitWidth() == 8)
|
||||||
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
|
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(newResultType).getElementType();
|
||||||
Value castedValue =
|
Value castedValue =
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
||||||
|
|
||||||
|
@ -553,7 +553,7 @@ public:
|
||||||
// The size of the result is calculated as follows:
|
// The size of the result is calculated as follows:
|
||||||
// ceil((end - start)/step)
|
// ceil((end - start)/step)
|
||||||
Value resultShape;
|
Value resultShape;
|
||||||
if (dtype.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(dtype)) {
|
||||||
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
|
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||||
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
|
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
|
||||||
} else {
|
} else {
|
||||||
|
@ -585,7 +585,7 @@ public:
|
||||||
index = castIndexToInt64(b, loc, index);
|
index = castIndexToInt64(b, loc, index);
|
||||||
index = convertScalarToDtype(b, loc, index, dtype);
|
index = convertScalarToDtype(b, loc, index, dtype);
|
||||||
Value mulOut, result;
|
Value mulOut, result;
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
mulOut = b.create<arith::MulFOp>(loc, step, index);
|
mulOut = b.create<arith::MulFOp>(loc, step, index);
|
||||||
result = b.create<arith::AddFOp>(loc, start, mulOut);
|
result = b.create<arith::AddFOp>(loc, start, mulOut);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -35,16 +35,16 @@ using namespace mlir::torch::Torch;
|
||||||
template <typename elementType> static bool hasElementType(Value tensor) {
|
template <typename elementType> static bool hasElementType(Value tensor) {
|
||||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||||
Type tensorElementType = tensorType.getElementType();
|
Type tensorElementType = tensorType.getElementType();
|
||||||
return tensorElementType.isa<elementType>();
|
return isa<elementType>(tensorElementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
|
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
|
||||||
arith::CmpIPredicate ispred>
|
arith::CmpIPredicate ispred>
|
||||||
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
|
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
|
||||||
Value lhs, Value rhs) {
|
Value lhs, Value rhs) {
|
||||||
if (type.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(type))
|
||||||
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
|
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
|
||||||
if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(type)) {
|
||||||
if (intType.isUnsigned())
|
if (intType.isUnsigned())
|
||||||
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
|
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
|
||||||
if (intType.isSigned())
|
if (intType.isSigned())
|
||||||
|
@ -319,7 +319,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(bitwiseAndScalar.getType())
|
Type dtype = converter->convertType(bitwiseAndScalar.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::IntegerType>()) {
|
if (!isa<mlir::IntegerType>(dtype)) {
|
||||||
bitwiseAndScalar.emitError(
|
bitwiseAndScalar.emitError(
|
||||||
"bitwise_and.Scalar does not support non-integer input dtype.");
|
"bitwise_and.Scalar does not support non-integer input dtype.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -371,7 +371,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
|
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::IntegerType>()) {
|
if (!isa<mlir::IntegerType>(dtype)) {
|
||||||
bitwiseRightShiftTensor.emitError(
|
bitwiseRightShiftTensor.emitError(
|
||||||
"Bitwise_Right_Shift op does not support non-integer input dtype.");
|
"Bitwise_Right_Shift op does not support non-integer input dtype.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -385,7 +385,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
|
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::IntegerType>()) {
|
if (!isa<mlir::IntegerType>(dtype)) {
|
||||||
bitwiseLeftShiftTensor.emitError(
|
bitwiseLeftShiftTensor.emitError(
|
||||||
"Bitwise_Left_Shift op does not support non-integer input dtype.");
|
"Bitwise_Left_Shift op does not support non-integer input dtype.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -623,7 +623,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||||
} else {
|
} else {
|
||||||
|
@ -647,7 +647,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType,
|
/*dstOriginalDtype=*/resultElementType,
|
||||||
/*originalScalar=*/sub.getAlpha());
|
/*originalScalar=*/sub.getAlpha());
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
||||||
} else {
|
} else {
|
||||||
|
@ -664,10 +664,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value alpha = convertScalarToDtype(
|
Value alpha = convertScalarToDtype(
|
||||||
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
|
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
|
||||||
/*dstOriginalDtype=*/dtype);
|
/*dstOriginalDtype=*/dtype);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
|
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
|
||||||
return b.create<arith::SubFOp>(loc, self, mult);
|
return b.create<arith::SubFOp>(loc, self, mult);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
|
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
|
||||||
return b.create<arith::SubIOp>(loc, self, mult);
|
return b.create<arith::SubIOp>(loc, self, mult);
|
||||||
}
|
}
|
||||||
|
@ -690,10 +690,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
|
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
|
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
|
||||||
return b.create<arith::AddFOp>(loc, self, mult);
|
return b.create<arith::AddFOp>(loc, self, mult);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
|
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
|
||||||
return b.create<arith::AddIOp>(loc, self, mult);
|
return b.create<arith::AddIOp>(loc, self, mult);
|
||||||
}
|
}
|
||||||
|
@ -708,9 +708,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||||
} else if (dtype.isa<mlir::ComplexType>()) {
|
} else if (isa<mlir::ComplexType>(dtype)) {
|
||||||
return b.create<complex::MulOp>(loc, lhs, rhs);
|
return b.create<complex::MulOp>(loc, lhs, rhs);
|
||||||
} else {
|
} else {
|
||||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||||
|
@ -720,7 +720,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(atan2.getType())
|
Type dtype = converter->convertType(atan2.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
atan2.emitError("Atan2 requires floating point result type");
|
atan2.emitError("Atan2 requires floating point result type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -759,9 +759,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::DivFOp>(loc, lhs, rhs);
|
return b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
else if (dtype.isa<mlir::IntegerType>()) {
|
else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
if (dtype.isUnsignedInteger())
|
if (dtype.isUnsignedInteger())
|
||||||
return b.create<arith::DivUIOp>(loc, lhs, rhs);
|
return b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||||
return b.create<arith::DivSIOp>(loc, lhs, rhs);
|
return b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||||
|
@ -777,7 +777,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
Value div;
|
Value div;
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
else {
|
else {
|
||||||
if (dtype.isUnsignedInteger())
|
if (dtype.isUnsignedInteger())
|
||||||
|
@ -798,7 +798,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
if (roundingMode == "trunc") {
|
if (roundingMode == "trunc") {
|
||||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||||
// to C-style integer division.
|
// to C-style integer division.
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||||
Value floor = b.create<math::FloorOp>(loc, div);
|
Value floor = b.create<math::FloorOp>(loc, div);
|
||||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
@ -811,7 +811,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
if (roundingMode == "floor") {
|
if (roundingMode == "floor") {
|
||||||
// "floor" - rounds the results of the division down. Equivalent to
|
// "floor" - rounds the results of the division down. Equivalent to
|
||||||
// floor division in Python (the // operator)
|
// floor division in Python (the // operator)
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<math::FloorOp>(loc, div);
|
return b.create<math::FloorOp>(loc, div);
|
||||||
else if (!dtype.isUnsignedInteger()) {
|
else if (!dtype.isUnsignedInteger()) {
|
||||||
Type defaultIntToFloatType = b.getF64Type();
|
Type defaultIntToFloatType = b.getF64Type();
|
||||||
|
@ -831,7 +831,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
|
|
||||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
pow.emitError("unimplemented: non-floating point dtype");
|
pow.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -857,7 +857,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(pow.getType())
|
Type dtype = converter->convertType(pow.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
pow.emitError("unimplemented: non-floating point dtype");
|
pow.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -870,7 +870,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(imag.getType())
|
Type dtype = converter->convertType(imag.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
imag.emitError("unimplemented: non-floating point dtype");
|
imag.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -882,7 +882,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(real.getType())
|
Type dtype = converter->convertType(real.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
real.emitError("unimplemented: non-floating point dtype");
|
real.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -898,10 +898,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor args from integer to float.
|
// TODO: Promote tensor args from integer to float.
|
||||||
gtScalar.emitError(
|
gtScalar.emitError(
|
||||||
|
@ -928,10 +928,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor args from integer to float.
|
// TODO: Promote tensor args from integer to float.
|
||||||
geScalar.emitError(
|
geScalar.emitError(
|
||||||
|
@ -955,7 +955,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (dtype.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
eqScalar.emitError(
|
eqScalar.emitError(
|
||||||
|
@ -971,7 +971,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (dtype.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
neScalar.emitError(
|
neScalar.emitError(
|
||||||
|
@ -989,10 +989,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
|
|
||||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||||
// a lot of code that can be refactored.
|
// a lot of code that can be refactored.
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
ltScalar.emitError(
|
ltScalar.emitError(
|
||||||
|
@ -1017,10 +1017,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
|
|
||||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
|
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
|
||||||
// that can be refactored.
|
// that can be refactored.
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
leScalar.emitError(
|
leScalar.emitError(
|
||||||
|
@ -1096,14 +1096,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(clamp.getType())
|
Type dtype = converter->convertType(clamp.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
|
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
|
||||||
clamp.emitError("unimplement type for clamp");
|
clamp.emitError("unimplement type for clamp");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
|
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
|
||||||
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
||||||
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
|
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
|
||||||
isUnsigned = intTy.isUnsigned();
|
isUnsigned = intTy.isUnsigned();
|
||||||
}
|
}
|
||||||
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
|
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
|
||||||
|
@ -1112,11 +1112,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
/*dstOriginalDtype=*/dstOriginalDtype);
|
/*dstOriginalDtype=*/dstOriginalDtype);
|
||||||
|
|
||||||
Value pred;
|
Value pred;
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
auto cmp =
|
auto cmp =
|
||||||
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
|
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
|
||||||
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
|
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
auto cmp =
|
auto cmp =
|
||||||
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
|
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
|
||||||
if (getMax)
|
if (getMax)
|
||||||
|
@ -1151,10 +1151,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
isMinNone = false;
|
isMinNone = false;
|
||||||
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
Value pred;
|
Value pred;
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result,
|
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result,
|
||||||
minPromoted);
|
minPromoted);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result,
|
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result,
|
||||||
minPromoted);
|
minPromoted);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1169,10 +1169,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
|
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
|
||||||
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||||
Value pred;
|
Value pred;
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result,
|
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result,
|
||||||
maxPromoted);
|
maxPromoted);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result,
|
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result,
|
||||||
maxPromoted);
|
maxPromoted);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1194,10 +1194,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value alpha = convertScalarToDtype(
|
Value alpha = convertScalarToDtype(
|
||||||
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
|
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
|
||||||
/*dstOriginalDtype=*/dtype);
|
/*dstOriginalDtype=*/dtype);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
|
||||||
return b.create<arith::SubFOp>(loc, other, mult);
|
return b.create<arith::SubFOp>(loc, other, mult);
|
||||||
} else if (dtype.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(dtype)) {
|
||||||
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
|
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
|
||||||
return b.create<arith::SubIOp>(loc, other, mult);
|
return b.create<arith::SubIOp>(loc, other, mult);
|
||||||
}
|
}
|
||||||
|
@ -1211,9 +1211,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||||
if (dtype.isa<mlir::IntegerType>())
|
if (isa<mlir::IntegerType>(dtype))
|
||||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||||
mulScalar.emitError("unimplemented: Only integer/float dtype supported");
|
mulScalar.emitError("unimplemented: Only integer/float dtype supported");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1246,7 +1246,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(divScalar.getType())
|
Type dtype = converter->convertType(divScalar.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
divScalar.emitError("unimplemented: non-floating point dtype");
|
divScalar.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1263,9 +1263,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
|
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
|
||||||
Value result;
|
Value result;
|
||||||
|
|
||||||
if (newResultType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(newResultType)) {
|
||||||
result = b.create<arith::RemFOp>(loc, self, other);
|
result = b.create<arith::RemFOp>(loc, self, other);
|
||||||
} else if (newResultType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(newResultType)) {
|
||||||
result = b.create<arith::RemSIOp>(loc, self, other);
|
result = b.create<arith::RemSIOp>(loc, self, other);
|
||||||
} else {
|
} else {
|
||||||
remScalar.emitError(
|
remScalar.emitError(
|
||||||
|
@ -1283,9 +1283,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
||||||
Value result;
|
Value result;
|
||||||
|
|
||||||
if (newResultType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(newResultType)) {
|
||||||
result = b.create<arith::RemFOp>(loc, self, other);
|
result = b.create<arith::RemFOp>(loc, self, other);
|
||||||
} else if (newResultType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(newResultType)) {
|
||||||
result = b.create<arith::RemSIOp>(loc, self, other);
|
result = b.create<arith::RemSIOp>(loc, self, other);
|
||||||
} else {
|
} else {
|
||||||
remTensor.emitError(
|
remTensor.emitError(
|
||||||
|
@ -1303,12 +1303,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
||||||
Value result;
|
Value result;
|
||||||
|
|
||||||
if (newResultType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(newResultType)) {
|
||||||
Value n = b.create<arith::DivFOp>(loc, self, other);
|
Value n = b.create<arith::DivFOp>(loc, self, other);
|
||||||
n = b.create<math::TruncOp>(loc, n);
|
n = b.create<math::TruncOp>(loc, n);
|
||||||
Value n_y = b.create<arith::MulFOp>(loc, n, other);
|
Value n_y = b.create<arith::MulFOp>(loc, n, other);
|
||||||
result = b.create<arith::SubFOp>(loc, self, n_y);
|
result = b.create<arith::SubFOp>(loc, self, n_y);
|
||||||
} else if (newResultType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(newResultType)) {
|
||||||
Value n = b.create<arith::DivSIOp>(loc, self, other);
|
Value n = b.create<arith::DivSIOp>(loc, self, other);
|
||||||
Value n_y = b.create<arith::MulIOp>(loc, n, other);
|
Value n_y = b.create<arith::MulIOp>(loc, n, other);
|
||||||
result = b.create<arith::SubIOp>(loc, self, n_y);
|
result = b.create<arith::SubIOp>(loc, self, n_y);
|
||||||
|
@ -1349,7 +1349,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
||||||
|
|
||||||
Value predicate;
|
Value predicate;
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||||
threshold);
|
threshold);
|
||||||
else
|
else
|
||||||
|
@ -1372,7 +1372,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
|
||||||
Value predicate;
|
Value predicate;
|
||||||
if (dtype.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(dtype))
|
||||||
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||||
threshold);
|
threshold);
|
||||||
else
|
else
|
||||||
|
@ -1426,7 +1426,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type elementType = converter->convertType(bitwiseNot.getType())
|
Type elementType = converter->convertType(bitwiseNot.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (elementType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementType)) {
|
||||||
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
|
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -2253,7 +2253,7 @@ public:
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = input.getType().cast<RankedTensorType>();
|
||||||
auto inputElementType = inputType.getElementType();
|
auto inputElementType = inputType.getElementType();
|
||||||
|
|
||||||
if (!inputElementType.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inputElementType)) {
|
||||||
op.emitError("Logit does not support non-floating point type");
|
op.emitError("Logit does not support non-floating point type");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -554,7 +554,7 @@ FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
|
||||||
}
|
}
|
||||||
Type type = *maybeType;
|
Type type = *maybeType;
|
||||||
// The linalg-on-tensors backend currently expects integers to be signless.
|
// The linalg-on-tensors backend currently expects integers to be signless.
|
||||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
if (auto intType = dyn_cast<IntegerType>(type)) {
|
||||||
type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
|
type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
|
|
|
@ -140,11 +140,11 @@ public:
|
||||||
|
|
||||||
// If the target type is non-torch type, then use TypeConverter to convert
|
// If the target type is non-torch type, then use TypeConverter to convert
|
||||||
// the type of the source.
|
// the type of the source.
|
||||||
if (targetType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(targetType)) {
|
||||||
targetType = Torch::FloatType::get(op->getContext());
|
targetType = Torch::FloatType::get(op->getContext());
|
||||||
torchArg = typeConverter->materializeSourceConversion(
|
torchArg = typeConverter->materializeSourceConversion(
|
||||||
rewriter, scfWhileOp.getLoc(), targetType, {to});
|
rewriter, scfWhileOp.getLoc(), targetType, {to});
|
||||||
} else if (targetType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(targetType)) {
|
||||||
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
||||||
if (bitWidth == 1)
|
if (bitWidth == 1)
|
||||||
targetType = Torch::BoolType::get(op->getContext());
|
targetType = Torch::BoolType::get(op->getContext());
|
||||||
|
@ -179,7 +179,7 @@ public:
|
||||||
|
|
||||||
// If the argument is a torch tensor, directly add it in the list of
|
// If the argument is a torch tensor, directly add it in the list of
|
||||||
// iter args.
|
// iter args.
|
||||||
if (torchType.isa<Torch::BaseTensorType>()) {
|
if (isa<Torch::BaseTensorType>(torchType)) {
|
||||||
loopConditionIterArgs.push_back(torchArg);
|
loopConditionIterArgs.push_back(torchArg);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -262,11 +262,11 @@ public:
|
||||||
|
|
||||||
// If the target type is non-torch type, then use TypeConverter to convert
|
// If the target type is non-torch type, then use TypeConverter to convert
|
||||||
// the type of the source.
|
// the type of the source.
|
||||||
if (targetType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(targetType)) {
|
||||||
targetType = Torch::FloatType::get(op->getContext());
|
targetType = Torch::FloatType::get(op->getContext());
|
||||||
torchArg = typeConverter->materializeSourceConversion(
|
torchArg = typeConverter->materializeSourceConversion(
|
||||||
rewriter, scfForOp.getLoc(), targetType, {to});
|
rewriter, scfForOp.getLoc(), targetType, {to});
|
||||||
} else if (targetType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(targetType)) {
|
||||||
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
|
||||||
if (bitWidth == 1)
|
if (bitWidth == 1)
|
||||||
targetType = Torch::BoolType::get(op->getContext());
|
targetType = Torch::BoolType::get(op->getContext());
|
||||||
|
|
|
@ -42,11 +42,11 @@ static Value getConstantLike(OpBuilder &b, Location loc, T constant,
|
||||||
Value val) {
|
Value val) {
|
||||||
Type ty = getElementTypeOrSelf(val.getType());
|
Type ty = getElementTypeOrSelf(val.getType());
|
||||||
auto getAttr = [&]() -> Attribute {
|
auto getAttr = [&]() -> Attribute {
|
||||||
if (ty.isa<mlir::IntegerType>())
|
if (isa<mlir::IntegerType>(ty))
|
||||||
return b.getIntegerAttr(ty, constant);
|
return b.getIntegerAttr(ty, constant);
|
||||||
if (ty.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(ty))
|
||||||
return b.getFloatAttr(ty, constant);
|
return b.getFloatAttr(ty, constant);
|
||||||
if (auto complexTy = ty.dyn_cast<mlir::ComplexType>())
|
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
|
||||||
return complex::NumberAttr::get(complexTy, constant, 0);
|
return complex::NumberAttr::get(complexTy, constant, 0);
|
||||||
llvm_unreachable("unhandled element type");
|
llvm_unreachable("unhandled element type");
|
||||||
};
|
};
|
||||||
|
@ -105,17 +105,17 @@ bool skipMultiplyAlpha(Value alphaValue) {
|
||||||
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto constType = RankedTensorType::get({}, elementType);
|
auto constType = RankedTensorType::get({}, elementType);
|
||||||
if (elementType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementType)) {
|
||||||
auto constAttr = SplatElementsAttr::get(
|
auto constAttr = SplatElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*negative=*/false));
|
/*negative=*/false));
|
||||||
return rewriter
|
return rewriter
|
||||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
if (elementType.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(elementType)) {
|
||||||
auto integerType = elementType.cast<mlir::IntegerType>();
|
auto integerType = cast<mlir::IntegerType>(elementType);
|
||||||
DenseElementsAttr constAttr;
|
DenseElementsAttr constAttr;
|
||||||
if (integerType.isUnsigned()) {
|
if (integerType.isUnsigned()) {
|
||||||
constAttr = SplatElementsAttr::get(
|
constAttr = SplatElementsAttr::get(
|
||||||
|
@ -134,17 +134,17 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
||||||
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto constType = RankedTensorType::get({}, elementType);
|
auto constType = RankedTensorType::get({}, elementType);
|
||||||
if (elementType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementType)) {
|
||||||
auto constAttr = SplatElementsAttr::get(
|
auto constAttr = SplatElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
/*negative=*/true));
|
/*negative=*/true));
|
||||||
return rewriter
|
return rewriter
|
||||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
if (elementType.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(elementType)) {
|
||||||
auto integerType = elementType.cast<mlir::IntegerType>();
|
auto integerType = cast<mlir::IntegerType>(elementType);
|
||||||
DenseElementsAttr constAttr;
|
DenseElementsAttr constAttr;
|
||||||
if (integerType.isUnsigned()) {
|
if (integerType.isUnsigned()) {
|
||||||
constAttr = SplatElementsAttr::get(
|
constAttr = SplatElementsAttr::get(
|
||||||
|
@ -446,7 +446,7 @@ public:
|
||||||
op, "only support constant str rounding mode");
|
op, "only support constant str rounding mode");
|
||||||
|
|
||||||
// if trunc and int, do nothing
|
// if trunc and int, do nothing
|
||||||
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) {
|
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
|
||||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||||
// to C-style integer division.
|
// to C-style integer division.
|
||||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
||||||
|
@ -457,7 +457,7 @@ public:
|
||||||
if (roundingMode == "floor") {
|
if (roundingMode == "floor") {
|
||||||
// "floor" - rounds the results of the division down. Equivalent to
|
// "floor" - rounds the results of the division down. Equivalent to
|
||||||
// floor division in Python (the // operator)
|
// floor division in Python (the // operator)
|
||||||
if (outElemTy.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(outElemTy))
|
||||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||||
else if (!outElemTy.isUnsignedInteger()) {
|
else if (!outElemTy.isUnsignedInteger()) {
|
||||||
TensorType defaultIntToFloatType =
|
TensorType defaultIntToFloatType =
|
||||||
|
@ -518,10 +518,10 @@ public:
|
||||||
chlo::ComparisonTypeAttr compareTypeAttr;
|
chlo::ComparisonTypeAttr compareTypeAttr;
|
||||||
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
||||||
|
|
||||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(lhsElemTy)) {
|
||||||
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
||||||
op->getContext(), chlo::ComparisonType::FLOAT);
|
op->getContext(), chlo::ComparisonType::FLOAT);
|
||||||
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(lhsElemTy)) {
|
||||||
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
||||||
op->getContext(), chlo::ComparisonType::SIGNED);
|
op->getContext(), chlo::ComparisonType::SIGNED);
|
||||||
}
|
}
|
||||||
|
@ -985,14 +985,14 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||||
auto lhsElemTy = lhsTy.getElementType();
|
auto lhsElemTy = lhsTy.getElementType();
|
||||||
|
|
||||||
if (!lhsElemTy.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(lhsElemTy)) {
|
||||||
return op->emitError("only float tensor in relu op is supported");
|
return op->emitError("only float tensor in relu op is supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value zeroTensor;
|
Value zeroTensor;
|
||||||
zeroTensor = getConstantLike(
|
zeroTensor = getConstantLike(
|
||||||
rewriter, op->getLoc(),
|
rewriter, op->getLoc(),
|
||||||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
|
||||||
false),
|
false),
|
||||||
lhs);
|
lhs);
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
||||||
|
@ -1160,7 +1160,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
rewriter.getI64IntegerAttr(feature_index));
|
rewriter.getI64IntegerAttr(feature_index));
|
||||||
output = hlo::promoteType(rewriter, op.getLoc(),
|
output = hlo::promoteType(rewriter, op.getLoc(),
|
||||||
batchNormTrainingResult.getResult(0),
|
batchNormTrainingResult.getResult(0),
|
||||||
outputTy.cast<TensorType>());
|
cast<TensorType>(outputTy));
|
||||||
} else {
|
} else {
|
||||||
auto batchNormTrainingResult =
|
auto batchNormTrainingResult =
|
||||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||||
|
@ -1204,7 +1204,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
runningVar, rewriter.getF32FloatAttr(eps),
|
runningVar, rewriter.getF32FloatAttr(eps),
|
||||||
rewriter.getI64IntegerAttr(feature_index));
|
rewriter.getI64IntegerAttr(feature_index));
|
||||||
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
|
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
|
||||||
outputTy.cast<TensorType>());
|
cast<TensorType>(outputTy));
|
||||||
} else {
|
} else {
|
||||||
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||||
|
@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
auto dtype = outType.getElementType();
|
auto dtype = outType.getElementType();
|
||||||
if (!dtype.isa<mlir::IntegerType>() && !dtype.isa<mlir::FloatType>()) {
|
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: only int or float dtype supported");
|
op, "unimplemented: only int or float dtype supported");
|
||||||
}
|
}
|
||||||
|
@ -1607,7 +1607,7 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
||||||
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
|
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
|
||||||
loc, rewriter.getI64TensorAttr(elements));
|
loc, rewriter.getI64TensorAttr(elements));
|
||||||
auto outTy = getTypeConverter()->convertType(op.getType());
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||||
auto outElemTy = outTy.cast<RankedTensorType>().getElementType();
|
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
|
||||||
Value from =
|
Value from =
|
||||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
|
||||||
Value to =
|
Value to =
|
||||||
|
|
|
@ -34,14 +34,14 @@ static Value createInitialValueForGatherScatterOp(Operation *op,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto elementTy = constType.getElementType();
|
auto elementTy = constType.getElementType();
|
||||||
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
|
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getZero(
|
constType, {APFloat::getZero(
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/false)});
|
/*negative=*/false)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
|
|
@ -37,14 +37,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
// Avg pooling
|
// Avg pooling
|
||||||
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
||||||
AtenCumsumOp>(op)) {
|
AtenCumsumOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getZero(
|
constType, {APFloat::getZero(
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/false)});
|
/*negative=*/false)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
@ -55,14 +55,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
|
|
||||||
// Max pooling
|
// Max pooling
|
||||||
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getInf(
|
constType,
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/true)});
|
/*negative=*/true)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
|
|
@ -37,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
auto constType = RankedTensorType::get({}, elementTy);
|
auto constType = RankedTensorType::get({}, elementTy);
|
||||||
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
|
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
|
||||||
AtenLinalgVectorNormOp>(op)) {
|
AtenLinalgVectorNormOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getZero(
|
constType, {APFloat::getZero(
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/false)});
|
/*negative=*/false)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||||
|
@ -54,14 +54,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
|
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getInf(
|
constType,
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/true)});
|
/*negative=*/true)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
@ -72,14 +72,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenMinOp>(op)) {
|
if (isa<AtenMinOp>(op)) {
|
||||||
if (elementTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getInf(
|
constType,
|
||||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
|
||||||
/*negative=*/false)});
|
/*negative=*/false)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
constAttr);
|
constAttr);
|
||||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
} else if (isa<mlir::IntegerType>(elementTy) &&
|
||||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
@ -234,7 +234,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
"only floating-point or integer datatype legalization supported");
|
"only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
// Currently, (u)int8 dtype is not supported!
|
// Currently, (u)int8 dtype is not supported!
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -305,7 +305,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
||||||
"Only floating-point or integer datatype legalization supported");
|
"Only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
// Currently, (u)int8 dtype is not supported
|
// Currently, (u)int8 dtype is not supported
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -319,7 +319,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
||||||
->convertType(op.getResult(1).getType())
|
->convertType(op.getResult(1).getType())
|
||||||
.template cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
Type idxElementType = idxResultType.getElementType();
|
Type idxElementType = idxResultType.getElementType();
|
||||||
if (!idxElementType.isa<mlir::IntegerType>()) {
|
if (!isa<mlir::IntegerType>(idxElementType)) {
|
||||||
return op.emitError("Aten.max.dim needs integer-like result");
|
return op.emitError("Aten.max.dim needs integer-like result");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,7 +404,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
||||||
"only floating-point or integer datatype legalization supported");
|
"only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
// Currently, (u)int8 dtype is not supported
|
// Currently, (u)int8 dtype is not supported
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -466,7 +466,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
||||||
"only floating-point or integer datatype legalization supported");
|
"only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
// Currently, (u)int8 dtype is not supported
|
// Currently, (u)int8 dtype is not supported
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -529,7 +529,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
|
||||||
"only floating-point or integer datatype legalization supported");
|
"only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
// Currently, (u)int8 dtype is not supported
|
// Currently, (u)int8 dtype is not supported
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -603,7 +603,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Currently, (u)int8 dtype is not supported
|
// Currently, (u)int8 dtype is not supported
|
||||||
if (inputElemTy.isa<mlir::IntegerType>() &&
|
if (isa<mlir::IntegerType>(inputElemTy) &&
|
||||||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||||
|
@ -715,7 +715,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
auto inputElemType = inputType.getElementType();
|
auto inputElemType = inputType.getElementType();
|
||||||
if (!inputElemType.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inputElemType)) {
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
|
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
|
||||||
}
|
}
|
||||||
|
@ -830,7 +830,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
auto outElemType = outType.getElementType();
|
auto outElemType = outType.getElementType();
|
||||||
if (!outElemType.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(outElemType)) {
|
||||||
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -912,7 +912,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
||||||
op->getLoc(), blockArgumentTy,
|
op->getLoc(), blockArgumentTy,
|
||||||
DenseElementsAttr::get(
|
DenseElementsAttr::get(
|
||||||
blockArgumentTy,
|
blockArgumentTy,
|
||||||
APFloat(outElemType.cast<mlir::FloatType>().getFloatSemantics(), 1)));
|
APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
|
||||||
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
|
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
|
||||||
op->getLoc(), blockArgumentTy, constantOne, ord);
|
op->getLoc(), blockArgumentTy, constantOne, ord);
|
||||||
auto output = rewriter.create<chlo::BroadcastPowOp>(
|
auto output = rewriter.create<chlo::BroadcastPowOp>(
|
||||||
|
|
|
@ -144,12 +144,12 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Unable to extract the scalar constant");
|
"Unable to extract the scalar constant");
|
||||||
|
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
||||||
(isFloat ? doubleValue : intValue),
|
(isFloat ? doubleValue : intValue),
|
||||||
dshape, dtype)
|
dshape, dtype)
|
||||||
.value();
|
.value();
|
||||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
auto w = intType.getWidth();
|
auto w = intType.getWidth();
|
||||||
if (w != 1 && w != 32 && w != 64)
|
if (w != 1 && w != 32 && w != 64)
|
||||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||||
|
@ -261,7 +261,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Type rhsAlphaMulElemType;
|
Type rhsAlphaMulElemType;
|
||||||
if (outElemTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(outElemTy)) {
|
||||||
rhsAlphaMulElemType = outElemTy;
|
rhsAlphaMulElemType = outElemTy;
|
||||||
} else {
|
} else {
|
||||||
// if output type is 64, input type should also be 32
|
// if output type is 64, input type should also be 32
|
||||||
|
@ -355,7 +355,7 @@ public:
|
||||||
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
|
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
|
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
|
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
|
||||||
if (lhsElemTy.isa<mlir::FloatType>() && isBitwiseOp) {
|
if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"For bitwise operators, only integer "
|
"For bitwise operators, only integer "
|
||||||
"datatype legalization is supported");
|
"datatype legalization is supported");
|
||||||
|
@ -442,8 +442,7 @@ public:
|
||||||
rhsTensor = rhsType ? rhs : rhsAsTensor;
|
rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (outElemTy.isa<mlir::FloatType>() ||
|
if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
|
||||||
outElemTy.isa<mlir::IntegerType>()) {
|
|
||||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template cast<TensorType>();
|
.template cast<TensorType>();
|
||||||
|
@ -1454,7 +1453,7 @@ public:
|
||||||
SmallVector<int64_t> matmulOutputShape(
|
SmallVector<int64_t> matmulOutputShape(
|
||||||
{matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
|
{matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
|
||||||
Type outputElemTy;
|
Type outputElemTy;
|
||||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(lhsElemTy)) {
|
||||||
outputElemTy = lhsElemTy;
|
outputElemTy = lhsElemTy;
|
||||||
} else { // qint8 emits i32 matmul output
|
} else { // qint8 emits i32 matmul output
|
||||||
outputElemTy = rewriter.getIntegerType(32);
|
outputElemTy = rewriter.getIntegerType(32);
|
||||||
|
@ -1898,7 +1897,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
|
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
|
||||||
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
|
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
|
||||||
// define a 48-bit int.
|
// define a 48-bit int.
|
||||||
if (inputElemTy.isa<quant::QuantizedType>()) {
|
if (isa<quant::QuantizedType>(inputElemTy)) {
|
||||||
SmallVector<int32_t> zeroVec(weightShape[0], 0);
|
SmallVector<int32_t> zeroVec(weightShape[0], 0);
|
||||||
bias = tosa::getConstTensor<int32_t>(
|
bias = tosa::getConstTensor<int32_t>(
|
||||||
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
|
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
|
||||||
|
@ -1915,7 +1914,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
op, "Bias provided but not a ranked tensor");
|
op, "Bias provided but not a ranked tensor");
|
||||||
}
|
}
|
||||||
auto biasElemTy =
|
auto biasElemTy =
|
||||||
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();
|
isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter.getI32Type();
|
||||||
|
|
||||||
int64_t groups;
|
int64_t groups;
|
||||||
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
|
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
|
||||||
|
@ -2098,7 +2097,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
.getResult();
|
.getResult();
|
||||||
|
|
||||||
Value rescaledResult = transposedOutput;
|
Value rescaledResult = transposedOutput;
|
||||||
if (inputElemTy.isa<quant::QuantizedType>()) {
|
if (isa<quant::QuantizedType>(inputElemTy)) {
|
||||||
rescaledResult = tosa::buildRescaleOpConvOutput(
|
rescaledResult = tosa::buildRescaleOpConvOutput(
|
||||||
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
|
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
|
||||||
}
|
}
|
||||||
|
@ -2230,7 +2229,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
if (toBcastType.getRank() > 1)
|
if (toBcastType.getRank() > 1)
|
||||||
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
|
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
|
||||||
|
|
||||||
RankedTensorType outTensorType = outType.cast<RankedTensorType>();
|
RankedTensorType outTensorType = cast<RankedTensorType>(outType);
|
||||||
SmallVector<int64_t> newShape = {
|
SmallVector<int64_t> newShape = {
|
||||||
makeShapeTorchCompatible(toBcastType.getShape())[0]};
|
makeShapeTorchCompatible(toBcastType.getShape())[0]};
|
||||||
for (auto i = 2; i < outTensorType.getRank(); ++i)
|
for (auto i = 2; i < outTensorType.getRank(); ++i)
|
||||||
|
@ -2677,7 +2676,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
||||||
op, "Only floating-point or integer datatype legalization supported");
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
|
||||||
// Integer types with width > 32 are not supported
|
// Integer types with width > 32 are not supported
|
||||||
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
|
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
|
||||||
if (selfIntType && selfIntType.getWidth() > 32) {
|
if (selfIntType && selfIntType.getWidth() > 32) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Integer types with width greater than 32 are not supported");
|
op, "Integer types with width greater than 32 are not supported");
|
||||||
|
@ -2956,7 +2955,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
|
||||||
auto selfElemTy = selfType.getElementType();
|
auto selfElemTy = selfType.getElementType();
|
||||||
if (!selfElemTy.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(selfElemTy)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point datatype legalization supported");
|
op, "Only floating-point datatype legalization supported");
|
||||||
}
|
}
|
||||||
|
@ -2993,7 +2992,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
|
||||||
auto selfElemTy = selfType.getElementType();
|
auto selfElemTy = selfType.getElementType();
|
||||||
if (!selfElemTy.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(selfElemTy)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point datatype legalization supported");
|
op, "Only floating-point datatype legalization supported");
|
||||||
}
|
}
|
||||||
|
@ -3057,7 +3056,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Integer types with width > 32 are not supported
|
// Integer types with width > 32 are not supported
|
||||||
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
|
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
|
||||||
if (selfIntType && selfIntType.getWidth() > 32) {
|
if (selfIntType && selfIntType.getWidth() > 32) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Integer types with width greater than 32 are not supported");
|
op, "Integer types with width greater than 32 are not supported");
|
||||||
|
@ -4235,7 +4234,7 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||||
"Unable to extract the scalar constant");
|
"Unable to extract the scalar constant");
|
||||||
|
|
||||||
auto outElemTy = resultType.getElementType();
|
auto outElemTy = resultType.getElementType();
|
||||||
if (outElemTy.isa<mlir::IntegerType>()) {
|
if (isa<mlir::IntegerType>(outElemTy)) {
|
||||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
||||||
op, resultType, DenseElementsAttr::get(resultType, {intValue}));
|
op, resultType, DenseElementsAttr::get(resultType, {intValue}));
|
||||||
} else if (outElemTy.isF64()) {
|
} else if (outElemTy.isF64()) {
|
||||||
|
@ -4383,7 +4382,7 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
||||||
|
|
||||||
auto divTensor = self;
|
auto divTensor = self;
|
||||||
// tosa::DivOp only supports int
|
// tosa::DivOp only supports int
|
||||||
if (outElemTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(outElemTy)) {
|
||||||
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
||||||
op.getLoc(), otherTensor.getType(), otherTensor);
|
op.getLoc(), otherTensor.getType(), otherTensor);
|
||||||
divTensor = rewriter.create<tosa::MulOp>(
|
divTensor = rewriter.create<tosa::MulOp>(
|
||||||
|
|
|
@ -119,7 +119,7 @@ tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||||
Value lhs, Value rhs) {
|
Value lhs, Value rhs) {
|
||||||
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
||||||
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
||||||
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
|
||||||
(void)rewriter.notifyMatchFailure(op,
|
(void)rewriter.notifyMatchFailure(op,
|
||||||
"tosa.div only supports integer type");
|
"tosa.div only supports integer type");
|
||||||
}
|
}
|
||||||
|
@ -213,7 +213,7 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||||
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
|
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
|
||||||
Type outType, Value paramsValue,
|
Type outType, Value paramsValue,
|
||||||
Value indicesValue) {
|
Value indicesValue) {
|
||||||
auto resultType = outType.dyn_cast<ShapedType>();
|
auto resultType = dyn_cast<ShapedType>(outType);
|
||||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
||||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
|
||||||
Operation *op, Type outType,
|
Operation *op, Type outType,
|
||||||
Value paramsValue, Value indicesValue,
|
Value paramsValue, Value indicesValue,
|
||||||
Value fillValues) {
|
Value fillValues) {
|
||||||
auto resultType = outType.dyn_cast<ShapedType>();
|
auto resultType = dyn_cast<ShapedType>(outType);
|
||||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
||||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
||||||
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -981,7 +981,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
Type elemType = output_type.getElementType();
|
Type elemType = output_type.getElementType();
|
||||||
if (!elemType.isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(elemType)) {
|
||||||
op->emitOpError("Only floating-point datatype legalization supported for "
|
op->emitOpError("Only floating-point datatype legalization supported for "
|
||||||
"AtenLinalgVectorNorm op");
|
"AtenLinalgVectorNorm op");
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
|
@ -154,7 +154,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||||
// Create a zero constant tensor of the desired type and shape.
|
// Create a zero constant tensor of the desired type and shape.
|
||||||
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||||
Operation *op, Type type) {
|
Operation *op, Type type) {
|
||||||
RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
|
RankedTensorType resultType = dyn_cast<RankedTensorType>(type);
|
||||||
|
|
||||||
if (!resultType) {
|
if (!resultType) {
|
||||||
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
|
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
|
||||||
|
@ -167,7 +167,7 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||||
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
|
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
|
||||||
|
|
||||||
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
|
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
|
||||||
zeroAttr.cast<ElementsAttr>())
|
cast<ElementsAttr>(zeroAttr))
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -312,7 +312,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||||
Value src, Type destType, Value &result) {
|
Value src, Type destType, Value &result) {
|
||||||
|
|
||||||
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
|
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
|
||||||
Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
|
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
|
||||||
|
|
||||||
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -392,7 +392,7 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
|
||||||
// Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
|
// Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
|
||||||
// FP16 is supported, the accumulator type can be selected based on trade-off
|
// FP16 is supported, the accumulator type can be selected based on trade-off
|
||||||
// between performance and accuracy. Set to FP32 by default.
|
// between performance and accuracy. Set to FP32 by default.
|
||||||
accType = inputETy.isa<FloatType>()
|
accType = isa<FloatType>(inputETy)
|
||||||
? mlir::TypeAttr::get(rewriter.getF32Type())
|
? mlir::TypeAttr::get(rewriter.getF32Type())
|
||||||
: mlir::TypeAttr::get(rewriter.getIntegerType(32));
|
: mlir::TypeAttr::get(rewriter.getIntegerType(32));
|
||||||
|
|
||||||
|
|
|
@ -27,9 +27,9 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
|
||||||
// TODO: Remove this check but use a separate verification pass to verify the
|
// TODO: Remove this check but use a separate verification pass to verify the
|
||||||
// invariants expected by later passes.
|
// invariants expected by later passes.
|
||||||
auto isValidLinalgType = [](Type type) {
|
auto isValidLinalgType = [](Type type) {
|
||||||
if (type.isa<NonValueTensorType>())
|
if (isa<NonValueTensorType>(type))
|
||||||
return false;
|
return false;
|
||||||
auto tensor = type.dyn_cast<ValueTensorType>();
|
auto tensor = dyn_cast<ValueTensorType>(type);
|
||||||
return !tensor ||
|
return !tensor ||
|
||||||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
|
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
|
||||||
};
|
};
|
||||||
|
@ -43,8 +43,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
|
||||||
|
|
||||||
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
|
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
|
||||||
Type type = v.getType();
|
Type type = v.getType();
|
||||||
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
|
if (isa<OptionalType>(type) || isa<Torch::NoneType>(type) ||
|
||||||
type.isa<mlir::NoneType>())
|
isa<mlir::NoneType>(type))
|
||||||
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
|
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
|
||||||
Type lhsType = lhsDim.getType();
|
Type lhsType = lhsDim.getType();
|
||||||
Type rhsType = rhsDim.getType();
|
Type rhsType = rhsDim.getType();
|
||||||
auto checkIntOrIndex = [](Type type) {
|
auto checkIntOrIndex = [](Type type) {
|
||||||
assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
|
assert((isa<IntegerType>(type) || isa<IndexType>(type)) &&
|
||||||
"must be either integer or index type");
|
"must be either integer or index type");
|
||||||
};
|
};
|
||||||
checkIntOrIndex(lhsType);
|
checkIntOrIndex(lhsType);
|
||||||
|
@ -198,13 +198,13 @@ Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
|
||||||
// Creates a constant of type `elemType` with value `val`.
|
// Creates a constant of type `elemType` with value `val`.
|
||||||
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
|
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
|
||||||
TypedAttr attr = {};
|
TypedAttr attr = {};
|
||||||
if (elemType.isa<mlir::FloatType>())
|
if (isa<mlir::FloatType>(elemType))
|
||||||
attr = b.getFloatAttr(elemType, val);
|
attr = b.getFloatAttr(elemType, val);
|
||||||
if (elemType.isa<mlir::IndexType>())
|
if (isa<mlir::IndexType>(elemType))
|
||||||
attr = b.getIndexAttr(val);
|
attr = b.getIndexAttr(val);
|
||||||
if (elemType.isa<mlir::IntegerType>())
|
if (isa<mlir::IntegerType>(elemType))
|
||||||
attr = b.getIntegerAttr(
|
attr = b.getIntegerAttr(elemType,
|
||||||
elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
|
APInt(cast<IntegerType>(elemType).getWidth(), val));
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return b.create<arith::ConstantOp>(loc, elemType, attr);
|
return b.create<arith::ConstantOp>(loc, elemType, attr);
|
||||||
|
@ -264,7 +264,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
return scalar;
|
return scalar;
|
||||||
|
|
||||||
auto isByteOrChar = [](Type type) {
|
auto isByteOrChar = [](Type type) {
|
||||||
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
|
if (auto integerTy = dyn_cast<mlir::IntegerType>(type)) {
|
||||||
return integerTy.getWidth() == 8;
|
return integerTy.getWidth() == 8;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -303,10 +303,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
if (dtype.isSignlessInteger(1)) {
|
if (dtype.isSignlessInteger(1)) {
|
||||||
Type scalarType = scalar.getType();
|
Type scalarType = scalar.getType();
|
||||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
|
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
|
||||||
if (scalarType.isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(scalarType)) {
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
|
||||||
cstZero);
|
cstZero);
|
||||||
} else if (scalarType.isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(scalarType)) {
|
||||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
|
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
|
||||||
cstZero);
|
cstZero);
|
||||||
} else {
|
} else {
|
||||||
|
@ -317,14 +317,14 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
|
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
|
||||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
|
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
|
||||||
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
||||||
return b.create<arith::TruncFOp>(loc, dtype, scalar);
|
return b.create<arith::TruncFOp>(loc, dtype, scalar);
|
||||||
// Only scalarFloat width < dtypeFloat width can reach here.
|
// Only scalarFloat width < dtypeFloat width can reach here.
|
||||||
return b.create<arith::ExtFOp>(loc, dtype, scalar);
|
return b.create<arith::ExtFOp>(loc, dtype, scalar);
|
||||||
}
|
}
|
||||||
assert(scalarType.isa<mlir::IntegerType>());
|
assert(isa<mlir::IntegerType>(scalarType));
|
||||||
if (scalarType.isSignlessInteger(1) ||
|
if (scalarType.isSignlessInteger(1) ||
|
||||||
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
|
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
|
||||||
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
|
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
|
||||||
|
@ -333,11 +333,11 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
|
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
|
if (auto dtypeInteger = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
|
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType))
|
||||||
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
|
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
|
||||||
assert(scalarType.isa<mlir::IntegerType>());
|
assert(isa<mlir::IntegerType>(scalarType));
|
||||||
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
|
auto scalarInteger = cast<mlir::IntegerType>(scalarType);
|
||||||
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
||||||
return b.create<arith::TruncIOp>(loc, dtype, scalar);
|
return b.create<arith::TruncIOp>(loc, dtype, scalar);
|
||||||
if (scalarType.isSignlessInteger(1) ||
|
if (scalarType.isSignlessInteger(1) ||
|
||||||
|
|
|
@ -49,7 +49,7 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
|
||||||
size_t resultIndex = en.index();
|
size_t resultIndex = en.index();
|
||||||
Type resultType = en.value();
|
Type resultType = en.value();
|
||||||
|
|
||||||
auto tensorType = resultType.dyn_cast<RankedTensorType>();
|
auto tensorType = dyn_cast<RankedTensorType>(resultType);
|
||||||
if (tensorType == nullptr) {
|
if (tensorType == nullptr) {
|
||||||
tmtensorOp.emitOpError()
|
tmtensorOp.emitOpError()
|
||||||
<< "tensor to buffer conversion expects ranked tensor results";
|
<< "tensor to buffer conversion expects ranked tensor results";
|
||||||
|
|
|
@ -100,10 +100,12 @@ void TorchDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
|
|
||||||
>();
|
>();
|
||||||
addTypes<
|
addTypes<
|
||||||
#define GET_TYPEDEF_LIST
|
#define GET_TYPEDEF_LIST
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
|
|
||||||
>();
|
>();
|
||||||
addInterfaces<TorchInlinerInterface>();
|
addInterfaces<TorchInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
@ -144,35 +146,34 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
||||||
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
||||||
Attribute value, Type type,
|
Attribute value, Type type,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (auto integerType = type.dyn_cast<Torch::IntType>())
|
if (auto integerType = dyn_cast<Torch::IntType>(type))
|
||||||
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
|
return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
|
||||||
|
|
||||||
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
if (auto floatType = dyn_cast<Torch::FloatType>(type))
|
||||||
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
|
||||||
|
|
||||||
if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
|
if (auto numberType = dyn_cast<Torch::NumberType>(type)) {
|
||||||
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
|
||||||
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
|
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
|
||||||
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
} else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
|
||||||
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
|
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type.isa<Torch::BoolType>()) {
|
if (isa<Torch::BoolType>(type)) {
|
||||||
return builder.create<Torch::ConstantBoolOp>(loc,
|
return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
|
||||||
value.cast<IntegerAttr>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type.isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(type))
|
||||||
return builder.create<ConstantNoneOp>(loc);
|
return builder.create<ConstantNoneOp>(loc);
|
||||||
|
|
||||||
if (auto stringAttr = value.dyn_cast<StringAttr>())
|
if (auto stringAttr = dyn_cast<StringAttr>(value))
|
||||||
return builder.create<ConstantStrOp>(loc, stringAttr);
|
return builder.create<ConstantStrOp>(loc, stringAttr);
|
||||||
|
|
||||||
if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
|
if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) {
|
||||||
// Only !torch.vtensor can be constant folded. !torch.tensor has
|
// Only !torch.vtensor can be constant folded. !torch.tensor has
|
||||||
// non-trivial aliasing semantics which prevent deduplicating it.
|
// non-trivial aliasing semantics which prevent deduplicating it.
|
||||||
assert(type.isa<ValueTensorType>() && "should be a vtensor type!");
|
assert(isa<ValueTensorType>(type) && "should be a vtensor type!");
|
||||||
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
|
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,9 +41,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
|
||||||
return value;
|
return value;
|
||||||
|
|
||||||
// If the type is a tensor, then adjust the static information.
|
// If the type is a tensor, then adjust the static information.
|
||||||
if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) ||
|
if ((isa<ValueTensorType>(type) && isa<ValueTensorType>(desiredType)) ||
|
||||||
(type.isa<NonValueTensorType>() &&
|
(isa<NonValueTensorType>(type) && isa<NonValueTensorType>(desiredType))) {
|
||||||
desiredType.isa<NonValueTensorType>())) {
|
|
||||||
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
|
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
|
||||||
desiredType, value);
|
desiredType, value);
|
||||||
return adjusted;
|
return adjusted;
|
||||||
|
@ -90,7 +89,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
// then we do the copy by going to a value tensor and back.
|
// then we do the copy by going to a value tensor and back.
|
||||||
if (tensor.getType().isa<NonValueTensorType>())
|
if (tensor.getType().isa<NonValueTensorType>())
|
||||||
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
||||||
if (newType.isa<NonValueTensorType>())
|
if (isa<NonValueTensorType>(newType))
|
||||||
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
||||||
|
|
||||||
return tensor;
|
return tensor;
|
||||||
|
@ -132,11 +131,11 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
||||||
static Value getScalarIntValue(Value input, Location loc,
|
static Value getScalarIntValue(Value input, Location loc,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto inputType = input.getType();
|
auto inputType = input.getType();
|
||||||
if (inputType.isa<Torch::IntType>()) {
|
if (isa<Torch::IntType>(inputType)) {
|
||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
|
auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
|
||||||
if (!inputTensorType)
|
if (!inputTensorType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -166,11 +165,11 @@ static Value getScalarIntValue(Value input, Location loc,
|
||||||
static Value getScalarFloatValue(Value input, Location loc,
|
static Value getScalarFloatValue(Value input, Location loc,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto inputType = input.getType();
|
auto inputType = input.getType();
|
||||||
if (inputType.isa<Torch::FloatType>()) {
|
if (isa<Torch::FloatType>(inputType)) {
|
||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
|
auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
|
||||||
if (!inputTensorType)
|
if (!inputTensorType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -273,7 +272,7 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
|
|
||||||
LogicalResult PrimListConstructOp::verify() {
|
LogicalResult PrimListConstructOp::verify() {
|
||||||
auto resultType = getResult().getType();
|
auto resultType = getResult().getType();
|
||||||
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
|
auto resultElementType = dyn_cast<ListType>(resultType).getContainedType();
|
||||||
auto matchResultElementType = [&](Type type) {
|
auto matchResultElementType = [&](Type type) {
|
||||||
return isValidSubtype(type, resultElementType);
|
return isValidSubtype(type, resultElementType);
|
||||||
};
|
};
|
||||||
|
@ -606,7 +605,7 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
|
||||||
Type rhsType = rhs.getType();
|
Type rhsType = rhs.getType();
|
||||||
|
|
||||||
// If either type is a NoneType, make it be the lhsType.
|
// If either type is a NoneType, make it be the lhsType.
|
||||||
if (rhsType.isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(rhsType)) {
|
||||||
std::swap(lhsType, rhsType);
|
std::swap(lhsType, rhsType);
|
||||||
std::swap(lhs, rhs);
|
std::swap(lhs, rhs);
|
||||||
}
|
}
|
||||||
|
@ -615,14 +614,14 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
|
||||||
|
|
||||||
// If both types are the singleton `!torch.none` type, then we don't even need
|
// If both types are the singleton `!torch.none` type, then we don't even need
|
||||||
// to look at the values.
|
// to look at the values.
|
||||||
if (lhsType.isa<Torch::NoneType>() && rhsType.isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(lhsType) && isa<Torch::NoneType>(rhsType))
|
||||||
return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);
|
return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);
|
||||||
|
|
||||||
// If neither type is a subtype of the other, then the result is false.
|
// If neither type is a subtype of the other, then the result is false.
|
||||||
// TODO: Implement and use subtype infra for this.
|
// TODO: Implement and use subtype infra for this.
|
||||||
// For now, check a specific case.
|
// For now, check a specific case.
|
||||||
// If the rhs is not OptionalType, then we know it cannot be None.
|
// If the rhs is not OptionalType, then we know it cannot be None.
|
||||||
if (lhsType.isa<Torch::NoneType>() && !rhsType.isa<Torch::OptionalType>()) {
|
if (isa<Torch::NoneType>(lhsType) && !isa<Torch::OptionalType>(rhsType)) {
|
||||||
return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
|
return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
|
||||||
!equalIsTrue);
|
!equalIsTrue);
|
||||||
}
|
}
|
||||||
|
@ -640,9 +639,9 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
|
||||||
auto step = adaptor.getStep();
|
auto step = adaptor.getStep();
|
||||||
if (!lo || !hi || !step)
|
if (!lo || !hi || !step)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto loInt = dyn_cast_or_null<IntegerAttr>(lo).getValue();
|
||||||
auto hiInt = hi.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto hiInt = dyn_cast_or_null<IntegerAttr>(hi).getValue();
|
||||||
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
|
||||||
// TODO: Implement folding for negative steps.
|
// TODO: Implement folding for negative steps.
|
||||||
if (stepInt.isNegative())
|
if (stepInt.isNegative())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -650,7 +649,7 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
|
||||||
// r[i] = lo + step*i such that i >= 0 and r[i] < hi
|
// r[i] = lo + step*i such that i >= 0 and r[i] < hi
|
||||||
// So maximize `i` such that lo + step * i < hi
|
// So maximize `i` such that lo + step * i < hi
|
||||||
// ==> i == ceildiv(hi - lo, step)
|
// ==> i == ceildiv(hi - lo, step)
|
||||||
return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
|
return IntegerAttr::get(cast<TypedAttr>(lo).getType(),
|
||||||
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
|
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
|
||||||
APInt::Rounding::UP));
|
APInt::Rounding::UP));
|
||||||
}
|
}
|
||||||
|
@ -665,10 +664,10 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
|
||||||
auto step = adaptor.getStep();
|
auto step = adaptor.getStep();
|
||||||
if (!index || !start || !step)
|
if (!index || !start || !step)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto indexInt = dyn_cast_or_null<IntegerAttr>(index).getValue();
|
||||||
auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto startInt = dyn_cast_or_null<IntegerAttr>(start).getValue();
|
||||||
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
|
auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
|
||||||
return IntegerAttr::get(index.cast<TypedAttr>().getType(),
|
return IntegerAttr::get(cast<TypedAttr>(index).getType(),
|
||||||
startInt + stepInt * indexInt);
|
startInt + stepInt * indexInt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2768,9 +2767,9 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
||||||
|
|
||||||
Value constValue;
|
Value constValue;
|
||||||
Attribute value = op.getValueAttr();
|
Attribute value = op.getValueAttr();
|
||||||
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
|
||||||
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
||||||
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
} else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
|
||||||
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
||||||
} else {
|
} else {
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -3192,9 +3191,9 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||||
BinaryFloatOperatorFn f) {
|
BinaryFloatOperatorFn f) {
|
||||||
double lhs, rhs;
|
double lhs, rhs;
|
||||||
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
||||||
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
|
if (auto intLhs = dyn_cast_or_null<IntegerAttr>(attr)) {
|
||||||
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
||||||
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
|
} else if (auto floatLhs = dyn_cast_or_null<FloatAttr>(attr)) {
|
||||||
value = floatLhs.getValue().convertToDouble();
|
value = floatLhs.getValue().convertToDouble();
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
@ -3945,7 +3944,7 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type resultType = getResult().getType();
|
Type resultType = getResult().getType();
|
||||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
|
||||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||||
!resultTensorType.hasSizes()) {
|
!resultTensorType.hasSizes()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -3966,11 +3965,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto elementType = shapedty.getElementType();
|
auto elementType = shapedty.getElementType();
|
||||||
if (elementType.isa<IntegerType>()) {
|
if (isa<IntegerType>(elementType)) {
|
||||||
Attribute attribute = IntegerAttr::get(elementType, 1);
|
Attribute attribute = IntegerAttr::get(elementType, 1);
|
||||||
return DenseElementsAttr::get(shapedty, attribute);
|
return DenseElementsAttr::get(shapedty, attribute);
|
||||||
}
|
}
|
||||||
if (elementType.isa<FloatType>()) {
|
if (isa<FloatType>(elementType)) {
|
||||||
Attribute attribute = FloatAttr::get(elementType, 1.0);
|
Attribute attribute = FloatAttr::get(elementType, 1.0);
|
||||||
return DenseElementsAttr::get(shapedty, attribute);
|
return DenseElementsAttr::get(shapedty, attribute);
|
||||||
}
|
}
|
||||||
|
@ -3984,7 +3983,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type resultType = getResult().getType();
|
Type resultType = getResult().getType();
|
||||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
|
||||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||||
!resultTensorType.hasSizes()) {
|
!resultTensorType.hasSizes()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -4006,11 +4005,11 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto elementType = shapedty.getElementType();
|
auto elementType = shapedty.getElementType();
|
||||||
if (elementType.isa<IntegerType>()) {
|
if (isa<IntegerType>(elementType)) {
|
||||||
Attribute attribute = IntegerAttr::get(elementType, 0);
|
Attribute attribute = IntegerAttr::get(elementType, 0);
|
||||||
return DenseElementsAttr::get(shapedty, attribute);
|
return DenseElementsAttr::get(shapedty, attribute);
|
||||||
}
|
}
|
||||||
if (elementType.isa<FloatType>()) {
|
if (isa<FloatType>(elementType)) {
|
||||||
Attribute attribute = FloatAttr::get(elementType, 0.0);
|
Attribute attribute = FloatAttr::get(elementType, 0.0);
|
||||||
return DenseElementsAttr::get(shapedty, attribute);
|
return DenseElementsAttr::get(shapedty, attribute);
|
||||||
}
|
}
|
||||||
|
@ -4025,7 +4024,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type resultType = getResult().getType();
|
Type resultType = getResult().getType();
|
||||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
|
||||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||||
!resultTensorType.hasSizes()) {
|
!resultTensorType.hasSizes()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -4043,14 +4042,14 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
|
||||||
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
|
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
|
||||||
|
|
||||||
auto elementType = shapedty.getElementType();
|
auto elementType = shapedty.getElementType();
|
||||||
if (elementType.isa<IntegerType>()) {
|
if (isa<IntegerType>(elementType)) {
|
||||||
int64_t value = 0;
|
int64_t value = 0;
|
||||||
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
|
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
|
||||||
Attribute attribute = IntegerAttr::get(elementType, value);
|
Attribute attribute = IntegerAttr::get(elementType, value);
|
||||||
return DenseElementsAttr::get(shapedty, attribute);
|
return DenseElementsAttr::get(shapedty, attribute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (elementType.isa<FloatType>()) {
|
if (isa<FloatType>(elementType)) {
|
||||||
double value = 0.0;
|
double value = 0.0;
|
||||||
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
|
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
|
||||||
Attribute attribute = FloatAttr::get(elementType, value);
|
Attribute attribute = FloatAttr::get(elementType, value);
|
||||||
|
@ -4631,15 +4630,14 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
|
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
|
||||||
for (Attribute symName : initialize.getSlotSymNames()) {
|
for (Attribute symName : initialize.getSlotSymNames()) {
|
||||||
auto wasInserted = initializedGlobalSlots
|
auto wasInserted = initializedGlobalSlots
|
||||||
.insert(symName.cast<FlatSymbolRefAttr>().getAttr())
|
.insert(cast<FlatSymbolRefAttr>(symName).getAttr())
|
||||||
.second;
|
.second;
|
||||||
if (!wasInserted)
|
if (!wasInserted)
|
||||||
return initialize.emitError("duplicate initialization of global slot: ")
|
return initialize.emitError("duplicate initialization of global slot: ")
|
||||||
<< symName;
|
<< symName;
|
||||||
}
|
}
|
||||||
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
|
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
|
||||||
return lhs.cast<StringAttr>().getValue() <
|
return cast<StringAttr>(lhs).getValue() < cast<StringAttr>(rhs).getValue();
|
||||||
rhs.cast<StringAttr>().getValue();
|
|
||||||
};
|
};
|
||||||
auto known = llvm::to_vector(knownGlobalSlots);
|
auto known = llvm::to_vector(knownGlobalSlots);
|
||||||
llvm::sort(known, lessThanByStringValue);
|
llvm::sort(known, lessThanByStringValue);
|
||||||
|
@ -4652,7 +4650,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
InFlightDiagnostic diag = initialize.emitOpError(
|
InFlightDiagnostic diag = initialize.emitOpError(
|
||||||
"must have one initializer for each global slot in the module");
|
"must have one initializer for each global slot in the module");
|
||||||
for (auto knownGlobalSlot : known) {
|
for (auto knownGlobalSlot : known) {
|
||||||
auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>());
|
auto symName = FlatSymbolRefAttr::get(cast<StringAttr>(knownGlobalSlot));
|
||||||
if (!initializedGlobalSlots.count(knownGlobalSlot)) {
|
if (!initializedGlobalSlots.count(knownGlobalSlot)) {
|
||||||
diag.attachNote(
|
diag.attachNote(
|
||||||
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
|
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
|
||||||
|
@ -4663,7 +4661,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
if (!knownGlobalSlots.count(initializedGlobalSlot)) {
|
if (!knownGlobalSlots.count(initializedGlobalSlot)) {
|
||||||
diag.attachNote().append(
|
diag.attachNote().append(
|
||||||
"unexpected global slot initializer for non-existent global slot ",
|
"unexpected global slot initializer for non-existent global slot ",
|
||||||
FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>()));
|
FlatSymbolRefAttr::get(cast<StringAttr>(initializedGlobalSlot)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return diag;
|
return diag;
|
||||||
|
|
|
@ -29,7 +29,7 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
||||||
|
|
||||||
// For a UnionType to be a subtype, all of its contained types must be
|
// For a UnionType to be a subtype, all of its contained types must be
|
||||||
// subtypes.
|
// subtypes.
|
||||||
if (auto unionType = subtype.dyn_cast<UnionType>()) {
|
if (auto unionType = dyn_cast<UnionType>(subtype)) {
|
||||||
for (auto containedType : unionType.getContainedTypes()) {
|
for (auto containedType : unionType.getContainedTypes()) {
|
||||||
if (!isValidSubtype(containedType, type))
|
if (!isValidSubtype(containedType, type))
|
||||||
return false;
|
return false;
|
||||||
|
@ -37,17 +37,17 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto any = type.dyn_cast<AnyType>())
|
if (auto any = dyn_cast<AnyType>(type))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (auto number = type.dyn_cast<NumberType>())
|
if (auto number = dyn_cast<NumberType>(type))
|
||||||
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
|
return isa<IntType>(subtype) || isa<Torch::FloatType>(subtype);
|
||||||
|
|
||||||
if (auto optional = type.dyn_cast<OptionalType>())
|
if (auto optional = dyn_cast<OptionalType>(type))
|
||||||
return isValidSubtype(subtype, optional.getContainedType()) ||
|
return isValidSubtype(subtype, optional.getContainedType()) ||
|
||||||
subtype.isa<Torch::NoneType>();
|
isa<Torch::NoneType>(subtype);
|
||||||
|
|
||||||
if (auto unionType = type.dyn_cast<UnionType>()) {
|
if (auto unionType = dyn_cast<UnionType>(type)) {
|
||||||
for (auto containedType : unionType.getContainedTypes()) {
|
for (auto containedType : unionType.getContainedTypes()) {
|
||||||
if (isValidSubtype(subtype, containedType))
|
if (isValidSubtype(subtype, containedType))
|
||||||
return true;
|
return true;
|
||||||
|
@ -55,10 +55,10 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
|
if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
|
||||||
if (!subtype.isa<Torch::TupleType>())
|
if (!isa<Torch::TupleType>(subtype))
|
||||||
return false;
|
return false;
|
||||||
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
|
auto subtypes = cast<Torch::TupleType>(subtype).getContainedTypes();
|
||||||
auto types = tuple.getContainedTypes();
|
auto types = tuple.getContainedTypes();
|
||||||
if (subtypes.size() != types.size())
|
if (subtypes.size() != types.size())
|
||||||
return false;
|
return false;
|
||||||
|
@ -69,14 +69,14 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
|
auto subtypeTensorType = dyn_cast<BaseTensorType>(subtype);
|
||||||
auto typeTensorType = type.dyn_cast<BaseTensorType>();
|
auto typeTensorType = dyn_cast<BaseTensorType>(type);
|
||||||
if (subtypeTensorType && typeTensorType) {
|
if (subtypeTensorType && typeTensorType) {
|
||||||
// Check that both tensors have the same `BaseTensorType` subtype.
|
// Check that both tensors have the same `BaseTensorType` subtype.
|
||||||
// TODO: This is not subtyping according to PEP 483. See description
|
// TODO: This is not subtyping according to PEP 483. See description
|
||||||
// of NonValueTensorType.
|
// of NonValueTensorType.
|
||||||
if (subtypeTensorType.isa<ValueTensorType>() !=
|
if (isa<ValueTensorType>(subtypeTensorType) !=
|
||||||
typeTensorType.isa<ValueTensorType>())
|
isa<ValueTensorType>(typeTensorType))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// `type` must not have more static information than `subtype`, and `type`
|
// `type` must not have more static information than `subtype`, and `type`
|
||||||
|
@ -181,23 +181,23 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
|
||||||
|
|
||||||
static bool isValidTorchDtype(Type dtype) {
|
static bool isValidTorchDtype(Type dtype) {
|
||||||
// For complex types, get the underlying element type
|
// For complex types, get the underlying element type
|
||||||
if (dtype.isa<ComplexType>()) {
|
if (isa<ComplexType>(dtype)) {
|
||||||
dtype = dtype.cast<ComplexType>().getElementType();
|
dtype = cast<ComplexType>(dtype).getElementType();
|
||||||
}
|
}
|
||||||
// Torch quantized types.
|
// Torch quantized types.
|
||||||
if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>())
|
if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>(dtype))
|
||||||
return true;
|
return true;
|
||||||
// Builtin floating point types.
|
// Builtin floating point types.
|
||||||
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
|
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
|
||||||
return true;
|
return true;
|
||||||
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
|
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (dtype.isa<Torch::StringType>())
|
if (isa<Torch::StringType>(dtype))
|
||||||
return true;
|
return true;
|
||||||
// Builtin integer types.
|
// Builtin integer types.
|
||||||
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
|
if (IntegerType type = dyn_cast<IntegerType>(dtype)) {
|
||||||
if (type.isSignless() && type.getWidth() == 1)
|
if (type.isSignless() && type.getWidth() == 1)
|
||||||
return true;
|
return true;
|
||||||
if (type.isSigned()) {
|
if (type.isSigned()) {
|
||||||
|
@ -273,7 +273,7 @@ verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!optionalSparsity.isa<sparse_tensor::SparseTensorEncodingAttr>()) {
|
if (!isa<sparse_tensor::SparseTensorEncodingAttr>(optionalSparsity)) {
|
||||||
emitError() << "invalid sparsity encoding attribute";
|
emitError() << "invalid sparsity encoding attribute";
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -441,12 +441,12 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
||||||
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
|
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
|
||||||
return dtype;
|
return dtype;
|
||||||
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
|
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
|
||||||
return IntegerType::get(context, integerType.getWidth(),
|
return IntegerType::get(context, integerType.getWidth(),
|
||||||
IntegerType::Signless);
|
IntegerType::Signless);
|
||||||
} else if (dtype.isa<mlir::ComplexType>()) {
|
} else if (isa<mlir::ComplexType>(dtype)) {
|
||||||
return dtype;
|
return dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -502,8 +502,8 @@ void ValueTensorType::print(AsmPrinter &printer) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
||||||
assert(((lhs.isa<ValueTensorType>() && rhs.isa<ValueTensorType>()) ||
|
assert(((isa<ValueTensorType>(lhs) && isa<ValueTensorType>(rhs)) ||
|
||||||
(lhs.isa<NonValueTensorType>() && rhs.isa<NonValueTensorType>())) &&
|
(isa<NonValueTensorType>(lhs) && isa<NonValueTensorType>(rhs))) &&
|
||||||
"expected lhs and rhs to have same sense of value semantics");
|
"expected lhs and rhs to have same sense of value semantics");
|
||||||
|
|
||||||
// First, calculate the dtype.
|
// First, calculate the dtype.
|
||||||
|
@ -566,21 +566,21 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
||||||
// linkage) and the predicates themselves can't be added/used in the
|
// linkage) and the predicates themselves can't be added/used in the
|
||||||
// specification of the parameters of the Torch_DictType.
|
// specification of the parameters of the Torch_DictType.
|
||||||
static bool isAnyTorchDictKeyType(Type type) {
|
static bool isAnyTorchDictKeyType(Type type) {
|
||||||
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
|
return isa<Torch::AnyType>(type) || isa<Torch::IntType>(type) ||
|
||||||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
|
isa<Torch::BoolType>(type) || isa<Torch::FloatType>(type) ||
|
||||||
type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>();
|
isa<Torch::StringType>(type) || isa<Torch::BaseTensorType>(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isAnyTorchType(Type type) {
|
static bool isAnyTorchType(Type type) {
|
||||||
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
|
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
|
||||||
type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() ||
|
isa<Torch::BaseTensorType>(type) || isa<Torch::AnyType>(type) ||
|
||||||
type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() ||
|
isa<Torch::BoolType>(type) || isa<Torch::DictType>(type) ||
|
||||||
type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() ||
|
isa<Torch::DeviceType>(type) || isa<Torch::GeneratorType>(type) ||
|
||||||
type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() ||
|
isa<Torch::ListType>(type) || isa<Torch::LinearParamsType>(type) ||
|
||||||
type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() ||
|
isa<Torch::NumberType>(type) || isa<Torch::NnModuleType>(type) ||
|
||||||
type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() ||
|
isa<Torch::NoneType>(type) || isa<Torch::OptionalType>(type) ||
|
||||||
type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() ||
|
isa<Torch::StringType>(type) || isa<Torch::TupleType>(type) ||
|
||||||
type.isa<Torch::UnionType>();
|
isa<Torch::UnionType>(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
|
|
|
@ -53,7 +53,7 @@ public:
|
||||||
auto typeBoundAttr =
|
auto typeBoundAttr =
|
||||||
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
||||||
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
||||||
if (!bound.isa<ValueTensorType>())
|
if (!isa<ValueTensorType>(bound))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
func, "unimplemented: preserving aliasing for non-value-semantic "
|
func, "unimplemented: preserving aliasing for non-value-semantic "
|
||||||
"type bounds");
|
"type bounds");
|
||||||
|
@ -72,10 +72,10 @@ public:
|
||||||
|
|
||||||
SmallVector<Type> newResultTypes;
|
SmallVector<Type> newResultTypes;
|
||||||
for (auto type : func.getFunctionType().getResults()) {
|
for (auto type : func.getFunctionType().getResults()) {
|
||||||
if (auto none = type.dyn_cast<Torch::NoneType>()) {
|
if (auto none = dyn_cast<Torch::NoneType>(type)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
|
if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
|
||||||
llvm::append_range(newResultTypes, tuple.getContainedTypes());
|
llvm::append_range(newResultTypes, tuple.getContainedTypes());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -133,12 +133,12 @@ public:
|
||||||
int newOpResultIdx = 0;
|
int newOpResultIdx = 0;
|
||||||
SmallVector<Value> newResults;
|
SmallVector<Value> newResults;
|
||||||
for (auto type : call.getResultTypes()) {
|
for (auto type : call.getResultTypes()) {
|
||||||
if (type.isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(type)) {
|
||||||
newResults.push_back(
|
newResults.push_back(
|
||||||
rewriter.create<ConstantNoneOp>(call.getLoc(), type));
|
rewriter.create<ConstantNoneOp>(call.getLoc(), type));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (type.isa<Torch::TupleType>()) {
|
if (isa<Torch::TupleType>(type)) {
|
||||||
newResults.push_back(rewriter.create<PrimTupleConstructOp>(
|
newResults.push_back(rewriter.create<PrimTupleConstructOp>(
|
||||||
call.getLoc(), type, newCall.getResults()));
|
call.getLoc(), type, newCall.getResults()));
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -1386,7 +1386,7 @@ static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
|
||||||
unNormalizedExp, sum);
|
unNormalizedExp, sum);
|
||||||
if (resultType != accumulatorType)
|
if (resultType != accumulatorType)
|
||||||
result = convertTensorToDtype(rewriter, loc, result,
|
result = convertTensorToDtype(rewriter, loc, result,
|
||||||
resultType.cast<BaseTensorType>().getDtype());
|
cast<BaseTensorType>(resultType).getDtype());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -1405,7 +1405,7 @@ public:
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
}
|
}
|
||||||
Type resultTensorDtype = resultTensorType.getDtype();
|
Type resultTensorDtype = resultTensorType.getDtype();
|
||||||
if (!resultTensorDtype.isa<mlir::FloatType>())
|
if (!isa<mlir::FloatType>(resultTensorDtype))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only support floating-point type");
|
"Only support floating-point type");
|
||||||
|
|
||||||
|
@ -1980,7 +1980,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Type dtype = resType.getDtype();
|
Type dtype = resType.getDtype();
|
||||||
if (dtype.isa<mlir::ComplexType>()) {
|
if (isa<mlir::ComplexType>(dtype)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "lowering of aten.linalg_cross for complex inputs dtype is "
|
op, "lowering of aten.linalg_cross for complex inputs dtype is "
|
||||||
"currently unimplemented");
|
"currently unimplemented");
|
||||||
|
@ -2015,7 +2015,7 @@ public:
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
|
||||||
// idx = torch.arange(3)
|
// idx = torch.arange(3)
|
||||||
auto outType = opType.dyn_cast<BaseTensorType>();
|
auto outType = dyn_cast<BaseTensorType>(opType);
|
||||||
auto arangeType = outType.getWithSizesAndDtype(
|
auto arangeType = outType.getWithSizesAndDtype(
|
||||||
llvm::ArrayRef<int64_t>(3),
|
llvm::ArrayRef<int64_t>(3),
|
||||||
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
|
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
|
||||||
|
@ -5848,7 +5848,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
Value keepDim = op.getKeepdim();
|
Value keepDim = op.getKeepdim();
|
||||||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||||||
Type outputType = op.getType();
|
Type outputType = op.getType();
|
||||||
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
|
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
|
||||||
if (!outputTensorType.hasDtype()) {
|
if (!outputTensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"expected result type to have a dtype");
|
"expected result type to have a dtype");
|
||||||
|
@ -5893,7 +5893,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
Type meanDimResultType = inputTensorTy;
|
Type meanDimResultType = inputTensorTy;
|
||||||
for (unsigned i = 0; i < dimListElements.size(); i++)
|
for (unsigned i = 0; i < dimListElements.size(); i++)
|
||||||
meanDimResultType = computeReductionType(
|
meanDimResultType = computeReductionType(
|
||||||
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
|
rewriter, op, cast<BaseTensorType>(meanDimResultType),
|
||||||
dimListElements[i],
|
dimListElements[i],
|
||||||
/*keepDim=*/true);
|
/*keepDim=*/true);
|
||||||
|
|
||||||
|
@ -6189,7 +6189,7 @@ public:
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Type resultType = op.getType();
|
Type resultType = op.getType();
|
||||||
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
|
BaseTensorType resultTensorType = cast<BaseTensorType>(resultType);
|
||||||
if (!resultTensorType.hasDtype()) {
|
if (!resultTensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
|
|
@ -207,7 +207,7 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type resultETy = resultTy.getDtype();
|
Type resultETy = resultTy.getDtype();
|
||||||
if (!resultETy.isa<mlir::FloatType>())
|
if (!isa<mlir::FloatType>(resultETy))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value lhsScale;
|
Value lhsScale;
|
||||||
|
|
|
@ -183,13 +183,13 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
|
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
|
||||||
if (Value value = point.dyn_cast<Value>()) {
|
if (Value value = dyn_cast<Value>(point)) {
|
||||||
bool isSafe = isValueSafeTransferFunction(value);
|
bool isSafe = isValueSafeTransferFunction(value);
|
||||||
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||||
propagateIfChanged(state, state->setSafe(isSafe));
|
propagateIfChanged(state, state->setSafe(isSafe));
|
||||||
|
|
||||||
// Handle GlobalSlotGetOp's.
|
// Handle GlobalSlotGetOp's.
|
||||||
if (auto opResult = value.dyn_cast<OpResult>()) {
|
if (auto opResult = dyn_cast<OpResult>(value)) {
|
||||||
if (auto globalSlotGet =
|
if (auto globalSlotGet =
|
||||||
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
|
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
|
||||||
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
|
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
|
@ -205,7 +205,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
if (auto *genericProgramPoint = point.dyn_cast<GenericProgramPoint *>()) {
|
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
|
||||||
if (auto *flatSymbolRefPoint =
|
if (auto *flatSymbolRefPoint =
|
||||||
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
|
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
|
||||||
if (initializeGlobalSlotsOp) {
|
if (initializeGlobalSlotsOp) {
|
||||||
|
@ -396,7 +396,7 @@ class InlineGlobalSlotsPass
|
||||||
// This could be left to SymbolDCE but it's not hard to do here.
|
// This could be left to SymbolDCE but it's not hard to do here.
|
||||||
for (FlatSymbolRefAttr symName :
|
for (FlatSymbolRefAttr symName :
|
||||||
llvm::map_range(safeToInline, [](Attribute attr) {
|
llvm::map_range(safeToInline, [](Attribute attr) {
|
||||||
return attr.cast<FlatSymbolRefAttr>();
|
return cast<FlatSymbolRefAttr>(attr);
|
||||||
})) {
|
})) {
|
||||||
auto globalSlot =
|
auto globalSlot =
|
||||||
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());
|
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());
|
||||||
|
|
|
@ -46,14 +46,14 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// can statically pattern match and eliminate from the program.
|
// can statically pattern match and eliminate from the program.
|
||||||
// For example, a tensor operand might be optional, and the backend
|
// For example, a tensor operand might be optional, and the backend
|
||||||
// will pattern-match statically whether it is passed as a tensor or None.
|
// will pattern-match statically whether it is passed as a tensor or None.
|
||||||
if (type.isa<Torch::NoneType, Torch::StringType>())
|
if (isa<Torch::NoneType, Torch::StringType>(type))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// We blanket prohibit non-value-semantic tensors.
|
// We blanket prohibit non-value-semantic tensors.
|
||||||
// All of our backends are currently based on value-semantic tensors, so
|
// All of our backends are currently based on value-semantic tensors, so
|
||||||
// we consider it our responsibility to lower all non-value-semantic tensors
|
// we consider it our responsibility to lower all non-value-semantic tensors
|
||||||
// to value-semantic tensors.
|
// to value-semantic tensors.
|
||||||
if (type.isa<NonValueTensorType>()) {
|
if (isa<NonValueTensorType>(type)) {
|
||||||
if (actuallyEmitDiagnostics) {
|
if (actuallyEmitDiagnostics) {
|
||||||
return op
|
return op
|
||||||
->emitError("unsupported by backend contract: non-value tensor type")
|
->emitError("unsupported by backend contract: non-value tensor type")
|
||||||
|
@ -84,7 +84,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// have an sufficiently rich system for representing PyTorch type promotion
|
// have an sufficiently rich system for representing PyTorch type promotion
|
||||||
// rules. So we consider it our responsibility to ensure that all dtypes are
|
// rules. So we consider it our responsibility to ensure that all dtypes are
|
||||||
// statically known.
|
// statically known.
|
||||||
if (auto tensorType = type.dyn_cast<ValueTensorType>()) {
|
if (auto tensorType = dyn_cast<ValueTensorType>(type)) {
|
||||||
if (!tensorType.hasSizes()) {
|
if (!tensorType.hasSizes()) {
|
||||||
if (actuallyEmitDiagnostics) {
|
if (actuallyEmitDiagnostics) {
|
||||||
return op
|
return op
|
||||||
|
@ -115,7 +115,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// Optional types are also in the category of types which we don't expect
|
// Optional types are also in the category of types which we don't expect
|
||||||
// backends to dynamically compute with, but they can be pattern matched
|
// backends to dynamically compute with, but they can be pattern matched
|
||||||
// in many cases that are practically necessary.
|
// in many cases that are practically necessary.
|
||||||
if (auto optionalType = type.dyn_cast<OptionalType>()) {
|
if (auto optionalType = dyn_cast<OptionalType>(type)) {
|
||||||
// TODO: Be stricter about tensor types.
|
// TODO: Be stricter about tensor types.
|
||||||
// See comment below for ListType.
|
// See comment below for ListType.
|
||||||
if (optionalType.getContainedType().isa<ValueTensorType>())
|
if (optionalType.getContainedType().isa<ValueTensorType>())
|
||||||
|
@ -127,7 +127,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// backends to dynamically compute with, but they can be pattern matched
|
// backends to dynamically compute with, but they can be pattern matched
|
||||||
// in many cases that are practically necessary. For example, the
|
// in many cases that are practically necessary. For example, the
|
||||||
// strides of a convolution op are represented as a list.
|
// strides of a convolution op are represented as a list.
|
||||||
if (auto listType = type.dyn_cast<ListType>()) {
|
if (auto listType = dyn_cast<ListType>(type)) {
|
||||||
// TODO: Be stricter about tensor types.
|
// TODO: Be stricter about tensor types.
|
||||||
// For the moment, there are cases (such as for torch.cat) where we end
|
// For the moment, there are cases (such as for torch.cat) where we end
|
||||||
// up with `!torch.list<vtensor>` which doesn't have shape or dtype in
|
// up with `!torch.list<vtensor>` which doesn't have shape or dtype in
|
||||||
|
@ -141,7 +141,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// Tuple types are also in the category of types which we don't expect
|
// Tuple types are also in the category of types which we don't expect
|
||||||
// backends to dynamically compute with, but they can be pattern matched
|
// backends to dynamically compute with, but they can be pattern matched
|
||||||
// in many cases that are practically necessary.
|
// in many cases that are practically necessary.
|
||||||
if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
|
if (auto tupleType = dyn_cast<Torch::TupleType>(type)) {
|
||||||
for (auto containedType : tupleType.getContainedTypes()) {
|
for (auto containedType : tupleType.getContainedTypes()) {
|
||||||
if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
|
if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -140,7 +140,7 @@ public:
|
||||||
auto returnOp = ops.returnOp.value();
|
auto returnOp = ops.returnOp.value();
|
||||||
for (auto operand : llvm::enumerate(returnOp->getOperands())) {
|
for (auto operand : llvm::enumerate(returnOp->getOperands())) {
|
||||||
auto type = operand.value().getType();
|
auto type = operand.value().getType();
|
||||||
if (!type.isa<NonValueTensorType>())
|
if (!isa<NonValueTensorType>(type))
|
||||||
continue;
|
continue;
|
||||||
originalReturnTypes[operand.index()] = type;
|
originalReturnTypes[operand.index()] = type;
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,15 +38,15 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
|
static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
|
||||||
if (auto optionalType = type.dyn_cast<OptionalType>()) {
|
if (auto optionalType = dyn_cast<OptionalType>(type)) {
|
||||||
Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
|
Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
|
||||||
optionalType.getContainedType());
|
optionalType.getContainedType());
|
||||||
return OptionalType::get(newContainedType);
|
return OptionalType::get(newContainedType);
|
||||||
} else if (auto listType = type.dyn_cast<ListType>()) {
|
} else if (auto listType = dyn_cast<ListType>(type)) {
|
||||||
Type newContainedType =
|
Type newContainedType =
|
||||||
getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
|
getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
|
||||||
return ListType::get(newContainedType);
|
return ListType::get(newContainedType);
|
||||||
} else if (auto tensorType = type.dyn_cast<NonValueTensorType>()) {
|
} else if (auto tensorType = dyn_cast<NonValueTensorType>(type)) {
|
||||||
return tensorType.getWithValueSemantics();
|
return tensorType.getWithValueSemantics();
|
||||||
} else {
|
} else {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -92,10 +92,10 @@ public:
|
||||||
SmallVector<Value> newOperands;
|
SmallVector<Value> newOperands;
|
||||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||||
Type operandType = opOperand.get().getType();
|
Type operandType = opOperand.get().getType();
|
||||||
if (operandType.isa<NonValueTensorType>()) {
|
if (isa<NonValueTensorType>(operandType)) {
|
||||||
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||||
opOperand.get()));
|
opOperand.get()));
|
||||||
} else if (auto listType = operandType.dyn_cast<ListType>()) {
|
} else if (auto listType = dyn_cast<ListType>(operandType)) {
|
||||||
if (!(listType.getContainedType().isa<NonValueTensorType>() ||
|
if (!(listType.getContainedType().isa<NonValueTensorType>() ||
|
||||||
listType.getContainedType().isa<OptionalType>()))
|
listType.getContainedType().isa<OptionalType>()))
|
||||||
continue;
|
continue;
|
||||||
|
@ -144,7 +144,7 @@ public:
|
||||||
}
|
}
|
||||||
opOperand.set(rewriter.create<PrimListConstructOp>(
|
opOperand.set(rewriter.create<PrimListConstructOp>(
|
||||||
op->getLoc(), newListType, newListElements));
|
op->getLoc(), newListType, newListElements));
|
||||||
} else if (auto optionalType = operandType.dyn_cast<OptionalType>()) {
|
} else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
|
||||||
// TODO: A more general way to handle the optional type is to
|
// TODO: A more general way to handle the optional type is to
|
||||||
// introduce a `copy.to_optional_vtensor` op.
|
// introduce a `copy.to_optional_vtensor` op.
|
||||||
if (!optionalType.getContainedType().isa<NonValueTensorType>())
|
if (!optionalType.getContainedType().isa<NonValueTensorType>())
|
||||||
|
@ -450,7 +450,7 @@ struct ReduceOpVariantsPass
|
||||||
auto hasValueSemantics = [](Type t) {
|
auto hasValueSemantics = [](Type t) {
|
||||||
// TODO: Make this an allowlist based on a closed torch dialect
|
// TODO: Make this an allowlist based on a closed torch dialect
|
||||||
// type system.
|
// type system.
|
||||||
if (auto tensorType = t.dyn_cast<NonValueTensorType>()) {
|
if (auto tensorType = dyn_cast<NonValueTensorType>(t)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -170,7 +170,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
if (operandType == desiredType)
|
if (operandType == desiredType)
|
||||||
return operand;
|
return operand;
|
||||||
|
|
||||||
if (desiredType.isa<Torch::AnyType>()) {
|
if (isa<Torch::AnyType>(desiredType)) {
|
||||||
// Generator's are currently passed as Any because TorchScript cannot
|
// Generator's are currently passed as Any because TorchScript cannot
|
||||||
// compile a function with Generator type arguments.
|
// compile a function with Generator type arguments.
|
||||||
// Ignoring that hack, this is a correct handling of Any type should we need
|
// Ignoring that hack, this is a correct handling of Any type should we need
|
||||||
|
@ -180,8 +180,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
|
|
||||||
// The type `!torch.number` can be an `int`, `float`, or `complex`.
|
// The type `!torch.number` can be an `int`, `float`, or `complex`.
|
||||||
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
|
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
|
||||||
if (desiredType.isa<Torch::NumberType>() &&
|
if (isa<Torch::NumberType>(desiredType) &&
|
||||||
operandType.isa<Torch::IntType, Torch::FloatType>()) {
|
isa<Torch::IntType, Torch::FloatType>(operandType)) {
|
||||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,7 +189,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// `Scalar` inputs. At compile time, such inputs will usually be
|
// `Scalar` inputs. At compile time, such inputs will usually be
|
||||||
// resolved to an `int`, `float`, or `None` so we need to derefine
|
// resolved to an `int`, `float`, or `None` so we need to derefine
|
||||||
// to match the library function signature.
|
// to match the library function signature.
|
||||||
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
|
if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
|
||||||
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
||||||
return containedType
|
return containedType
|
||||||
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
|
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
|
||||||
|
@ -200,8 +200,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// Operands with type `!torch.none` correspond to library function inputs with
|
// Operands with type `!torch.none` correspond to library function inputs with
|
||||||
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
|
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
|
||||||
// type is derefined to match the expected type of the library function.
|
// type is derefined to match the expected type of the library function.
|
||||||
if (operandType.isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(operandType)) {
|
||||||
assert(!desiredType.isa<Torch::NoneType>() &&
|
assert(!isa<Torch::NoneType>(desiredType) &&
|
||||||
"Don't expect library functions to have NoneType parameters");
|
"Don't expect library functions to have NoneType parameters");
|
||||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||||
}
|
}
|
||||||
|
@ -211,8 +211,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// dtype of input scalars. However, this also means we sometimes have to
|
// dtype of input scalars. However, this also means we sometimes have to
|
||||||
// manually turn `Scalar`s into `float`s when inserting the shape functions
|
// manually turn `Scalar`s into `float`s when inserting the shape functions
|
||||||
// into the IR.
|
// into the IR.
|
||||||
if (operandType.isa<Torch::NumberType>() &&
|
if (isa<Torch::NumberType>(operandType) &&
|
||||||
desiredType.isa<Torch::FloatType>()) {
|
isa<Torch::FloatType>(desiredType)) {
|
||||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,8 +224,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// type).
|
// type).
|
||||||
// A case where this happens is `!torch.optional<vtensor>` ->
|
// A case where this happens is `!torch.optional<vtensor>` ->
|
||||||
// `!torch.optional<list<int>>>`.
|
// `!torch.optional<list<int>>>`.
|
||||||
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
|
if (auto operandOptionalType = dyn_cast<Torch::OptionalType>(operandType)) {
|
||||||
if (desiredType.isa<Torch::OptionalType>()) {
|
if (isa<Torch::OptionalType>(desiredType)) {
|
||||||
// if optional is None:
|
// if optional is None:
|
||||||
// return derefine(None)
|
// return derefine(None)
|
||||||
// else:
|
// else:
|
||||||
|
@ -258,7 +258,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// If the desired type is OptionalType, then recursively adjust the operand to
|
// If the desired type is OptionalType, then recursively adjust the operand to
|
||||||
// the contained type, then derefine it to `!torch.optional`. For example,
|
// the contained type, then derefine it to `!torch.optional`. For example,
|
||||||
// `!torch.vtensor -> !torch.optional<list<int>>>`.
|
// `!torch.vtensor -> !torch.optional<list<int>>>`.
|
||||||
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
|
if (auto desiredOptionalType = dyn_cast<Torch::OptionalType>(desiredType)) {
|
||||||
FailureOr<Value> adjusted = adjustFunctionArg(
|
FailureOr<Value> adjusted = adjustFunctionArg(
|
||||||
b, loc, operand, desiredOptionalType.getContainedType(),
|
b, loc, operand, desiredOptionalType.getContainedType(),
|
||||||
baseTransformation);
|
baseTransformation);
|
||||||
|
@ -267,7 +267,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
|
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
|
if (auto desiredListType = dyn_cast<Torch::ListType>(desiredType)) {
|
||||||
// Pseudocode:
|
// Pseudocode:
|
||||||
//
|
//
|
||||||
// operand = ...
|
// operand = ...
|
||||||
|
@ -311,7 +311,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// The library functions use `float` where the operator
|
// The library functions use `float` where the operator
|
||||||
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
||||||
// explanation).
|
// explanation).
|
||||||
if (desiredType.isa<Torch::FloatType>() &&
|
if (isa<Torch::FloatType>(desiredType) &&
|
||||||
operand.getType().isa<Torch::IntType>()) {
|
operand.getType().isa<Torch::IntType>()) {
|
||||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
|
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
|
||||||
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||||
Type desiredType) -> Value {
|
Type desiredType) -> Value {
|
||||||
if (desiredType.isa<Torch::TupleType>() &&
|
if (isa<Torch::TupleType>(desiredType) &&
|
||||||
operand.getType().isa<Torch::BaseTensorType>()) {
|
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||||
Type intType = Torch::IntType::get(b.getContext());
|
Type intType = Torch::IntType::get(b.getContext());
|
||||||
Type sizeListType = Torch::ListType::get(intType);
|
Type sizeListType = Torch::ListType::get(intType);
|
||||||
|
|
|
@ -38,7 +38,7 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
Type desiredType) -> Value {
|
Type desiredType) -> Value {
|
||||||
// The shape library functions have tensor operands replaced with
|
// The shape library functions have tensor operands replaced with
|
||||||
// `!torch.list<int>` types for the shape. Get the sizes.
|
// `!torch.list<int>` types for the shape. Get the sizes.
|
||||||
auto desiredListType = desiredType.dyn_cast<Torch::ListType>();
|
auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
|
||||||
if (!desiredListType)
|
if (!desiredListType)
|
||||||
return operand;
|
return operand;
|
||||||
if (operand.getType().isa<Torch::BaseTensorType>() &&
|
if (operand.getType().isa<Torch::BaseTensorType>() &&
|
||||||
|
|
|
@ -262,13 +262,13 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
originalResultType.template dyn_cast<BaseTensorType>()) {
|
originalResultType.template dyn_cast<BaseTensorType>()) {
|
||||||
// If we didn't get any new information, there is nothing left for us to do.
|
// If we didn't get any new information, there is nothing left for us to do.
|
||||||
updatedType = meetTensorTypes(originalBaseTensorType,
|
updatedType = meetTensorTypes(originalBaseTensorType,
|
||||||
newResultType.cast<BaseTensorType>());
|
cast<BaseTensorType>(newResultType));
|
||||||
if (!updatedType || updatedType == originalBaseTensorType)
|
if (!updatedType || updatedType == originalBaseTensorType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
calculateOp, "New type information does not refine old type");
|
calculateOp, "New type information does not refine old type");
|
||||||
} else if (auto originalResultType =
|
} else if (auto originalResultType =
|
||||||
result.getType().template dyn_cast<Torch::NumberType>()) {
|
result.getType().template dyn_cast<Torch::NumberType>()) {
|
||||||
if (!newResultType.isa<Torch::FloatType, Torch::IntType>()) {
|
if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
calculateOp,
|
calculateOp,
|
||||||
"Refinement of `NumberType` must be a `FloatType` or `IntType`");
|
"Refinement of `NumberType` must be a `FloatType` or `IntType`");
|
||||||
|
@ -291,10 +291,10 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
}
|
}
|
||||||
if (!originalTypedValue) {
|
if (!originalTypedValue) {
|
||||||
rewriter.setInsertionPointAfter(calculateOp);
|
rewriter.setInsertionPointAfter(calculateOp);
|
||||||
if (originalResultType.isa<BaseTensorType>()) {
|
if (isa<BaseTensorType>(originalResultType)) {
|
||||||
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
|
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
|
||||||
loc, originalResultType, result);
|
loc, originalResultType, result);
|
||||||
} else if (originalResultType.isa<Torch::NumberType>()) {
|
} else if (isa<Torch::NumberType>(originalResultType)) {
|
||||||
originalTypedValue =
|
originalTypedValue =
|
||||||
rewriter.create<DerefineOp>(loc, originalResultType, result);
|
rewriter.create<DerefineOp>(loc, originalResultType, result);
|
||||||
} else {
|
} else {
|
||||||
|
@ -314,14 +314,14 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
OpOperand &use = yieldValues->getOpOperand(resultNum);
|
OpOperand &use = yieldValues->getOpOperand(resultNum);
|
||||||
Value def = use.get();
|
Value def = use.get();
|
||||||
Value newYieldedValue;
|
Value newYieldedValue;
|
||||||
if (def.isa<OpResult>() &&
|
if (isa<OpResult>(def) &&
|
||||||
def.cast<OpResult>()
|
cast<OpResult>(def)
|
||||||
.getDefiningOp()
|
.getDefiningOp()
|
||||||
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
||||||
newYieldedValue = def;
|
newYieldedValue = def;
|
||||||
} else {
|
} else {
|
||||||
rewriter.setInsertionPoint(yieldValues);
|
rewriter.setInsertionPoint(yieldValues);
|
||||||
if (updatedType.isa<BaseTensorType>()) {
|
if (isa<BaseTensorType>(updatedType)) {
|
||||||
newYieldedValue =
|
newYieldedValue =
|
||||||
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
|
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -53,8 +53,9 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
||||||
op, "Failed to convert `dtypeScalarType` to a builtin type");
|
op, "Failed to convert `dtypeScalarType` to a builtin type");
|
||||||
}
|
}
|
||||||
impliedTypeFromDtype =
|
impliedTypeFromDtype =
|
||||||
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
cast<BaseTensorType>(originalResultType)
|
||||||
originalResultType.getOptionalSizes(), *builtinType);
|
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
|
||||||
|
*builtinType);
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Unimplemented: Expected result type to "
|
"Unimplemented: Expected result type to "
|
||||||
|
@ -179,7 +180,7 @@ public:
|
||||||
}
|
}
|
||||||
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
|
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
|
||||||
auto impliedTypeFromInputType =
|
auto impliedTypeFromInputType =
|
||||||
originalResultType.cast<BaseTensorType>()
|
cast<BaseTensorType>(originalResultType)
|
||||||
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
|
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
|
||||||
inputType)
|
inputType)
|
||||||
.cast<BaseTensorType>();
|
.cast<BaseTensorType>();
|
||||||
|
|
|
@ -98,7 +98,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||||
|
|
||||||
auto originalResultType = result.getType().cast<BaseTensorType>();
|
auto originalResultType = result.getType().cast<BaseTensorType>();
|
||||||
auto impliedTypesFromShape =
|
auto impliedTypesFromShape =
|
||||||
originalResultType.cast<BaseTensorType>()
|
cast<BaseTensorType>(originalResultType)
|
||||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||||
originalResultType.getOptionalDtype())
|
originalResultType.getOptionalDtype())
|
||||||
.cast<BaseTensorType>();
|
.cast<BaseTensorType>();
|
||||||
|
|
|
@ -70,8 +70,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
return torch_upstream::ScalarType::QInt8;
|
return torch_upstream::ScalarType::QInt8;
|
||||||
if (type.isa<QInt32Type>())
|
if (type.isa<QInt32Type>())
|
||||||
return torch_upstream::ScalarType::QInt32;
|
return torch_upstream::ScalarType::QInt32;
|
||||||
if (type.isa<ComplexType>()) {
|
if (isa<ComplexType>(type)) {
|
||||||
mlir::Type complexElemType = type.cast<ComplexType>().getElementType();
|
mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
|
||||||
if (complexElemType.isF16())
|
if (complexElemType.isF16())
|
||||||
return torch_upstream::ScalarType::ComplexHalf;
|
return torch_upstream::ScalarType::ComplexHalf;
|
||||||
if (complexElemType.isF32())
|
if (complexElemType.isF32())
|
||||||
|
@ -84,9 +84,9 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
Type Torch::getTypeForTorchType(
|
Type Torch::getTypeForTorchType(
|
||||||
MLIRContext *context, Type type,
|
MLIRContext *context, Type type,
|
||||||
mlir::IntegerType::SignednessSemantics signedness) {
|
mlir::IntegerType::SignednessSemantics signedness) {
|
||||||
if (type.isa<Torch::IntType>())
|
if (isa<Torch::IntType>(type))
|
||||||
return IntegerType::get(context, 64, signedness);
|
return IntegerType::get(context, 64, signedness);
|
||||||
if (type.isa<Torch::FloatType>())
|
if (isa<Torch::FloatType>(type))
|
||||||
return Float64Type::get(context);
|
return Float64Type::get(context);
|
||||||
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
||||||
}
|
}
|
||||||
|
@ -150,14 +150,14 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
|
||||||
|
|
||||||
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
|
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
|
||||||
MLIRContext *context = type.getContext();
|
MLIRContext *context = type.getContext();
|
||||||
if (type.isa<Torch::FloatType>()) {
|
if (isa<Torch::FloatType>(type)) {
|
||||||
// For now, use float32 which is the initial default dtype returned by
|
// For now, use float32 which is the initial default dtype returned by
|
||||||
// `torch.get_default_dtype`.
|
// `torch.get_default_dtype`.
|
||||||
return Float32Type::get(context);
|
return Float32Type::get(context);
|
||||||
}
|
}
|
||||||
if (type.isa<Torch::IntType>())
|
if (isa<Torch::IntType>(type))
|
||||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
return IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
if (type.isa<Torch::BoolType>())
|
if (isa<Torch::BoolType>(type))
|
||||||
return IntegerType::get(context, 1);
|
return IntegerType::get(context, 1);
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"getDefaultDtypeForTorchScalar called on an unsupported type");
|
"getDefaultDtypeForTorchScalar called on an unsupported type");
|
||||||
|
@ -165,11 +165,11 @@ Type Torch::getDefaultDtypeForTorchScalar(Type type) {
|
||||||
|
|
||||||
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
|
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
|
||||||
MLIRContext *context = type.getContext();
|
MLIRContext *context = type.getContext();
|
||||||
if (type.isa<Torch::FloatType>())
|
if (isa<Torch::FloatType>(type))
|
||||||
return Float64Type::get(context);
|
return Float64Type::get(context);
|
||||||
if (type.isa<Torch::IntType>())
|
if (isa<Torch::IntType>(type))
|
||||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
return IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
if (type.isa<Torch::BoolType>())
|
if (isa<Torch::BoolType>(type))
|
||||||
return IntegerType::get(context, 1);
|
return IntegerType::get(context, 1);
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"getBuiltInTypeForTorchScalar called on an unsupported type");
|
"getBuiltInTypeForTorchScalar called on an unsupported type");
|
||||||
|
|
|
@ -62,15 +62,14 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
|
||||||
Attribute value,
|
Attribute value,
|
||||||
Type type,
|
Type type,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (auto integerType = type.dyn_cast<Torch::IntType>())
|
if (auto integerType = dyn_cast<Torch::IntType>(type))
|
||||||
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
|
return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
|
||||||
|
|
||||||
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
if (auto floatType = dyn_cast<Torch::FloatType>(type))
|
||||||
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
|
||||||
|
|
||||||
if (type.isa<Torch::BoolType>()) {
|
if (isa<Torch::BoolType>(type)) {
|
||||||
return builder.create<Torch::ConstantBoolOp>(loc,
|
return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
|
||||||
value.cast<IntegerAttr>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return arith::ConstantOp::materialize(builder, value, type, loc);
|
return arith::ConstantOp::materialize(builder, value, type, loc);
|
||||||
|
|
|
@ -95,7 +95,7 @@ public:
|
||||||
|
|
||||||
// get outputs
|
// get outputs
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType(0));
|
Type newResultType = getTypeConverter()->convertType(op.getType(0));
|
||||||
auto resultType = newResultType.cast<RankedTensorType>();
|
auto resultType = cast<RankedTensorType>(newResultType);
|
||||||
if (!resultType) {
|
if (!resultType) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ class VerifyStablehloBackendContractPass
|
||||||
converter.addConversion([](Type type) -> Type {
|
converter.addConversion([](Type type) -> Type {
|
||||||
auto elemTy = type;
|
auto elemTy = type;
|
||||||
if (isa<TensorType>(type))
|
if (isa<TensorType>(type))
|
||||||
elemTy = type.cast<TensorType>().getElementType();
|
elemTy = cast<TensorType>(type).getElementType();
|
||||||
if (BaseMemRefType::isValidElementType(elemTy))
|
if (BaseMemRefType::isValidElementType(elemTy))
|
||||||
return type;
|
return type;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -54,11 +54,11 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static bool isArgMemRefTypeValid(Type type) {
|
static bool isArgMemRefTypeValid(Type type) {
|
||||||
if (auto memRefType = type.dyn_cast<MemRefType>()) {
|
if (auto memRefType = dyn_cast<MemRefType>(type)) {
|
||||||
Type elemTy = memRefType.getElementType();
|
Type elemTy = memRefType.getElementType();
|
||||||
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
|
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
|
||||||
return true;
|
return true;
|
||||||
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
|
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
|
||||||
if (integerTy.isSignlessInteger(64))
|
if (integerTy.isSignlessInteger(64))
|
||||||
return true;
|
return true;
|
||||||
if (integerTy.isSignlessInteger(32))
|
if (integerTy.isSignlessInteger(32))
|
||||||
|
@ -69,7 +69,7 @@ static bool isArgMemRefTypeValid(Type type) {
|
||||||
return true;
|
return true;
|
||||||
if (integerTy.isSignlessInteger(1))
|
if (integerTy.isSignlessInteger(1))
|
||||||
return true;
|
return true;
|
||||||
} else if (auto complexTy = elemTy.dyn_cast<ComplexType>()) {
|
} else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
|
||||||
return complexTy.getElementType().isa<Float32Type, Float64Type>();
|
return complexTy.getElementType().isa<Float32Type, Float64Type>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ static void addEmitCInterfaceAttr(func::FuncOp func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type getAbiTypeForMemRef(Type type) {
|
static Type getAbiTypeForMemRef(Type type) {
|
||||||
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
|
return UnrankedMemRefType::get(cast<MemRefType>(type).getElementType(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to get the type string for one return value like i32, f64,
|
// Helper function to get the type string for one return value like i32, f64,
|
||||||
|
@ -90,12 +90,12 @@ static Type getAbiTypeForMemRef(Type type) {
|
||||||
static std::string getTypeToken(Type type) {
|
static std::string getTypeToken(Type type) {
|
||||||
if (type.isSignlessInteger())
|
if (type.isSignlessInteger())
|
||||||
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
|
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
|
||||||
else if (type.isa<mlir::FloatType>())
|
else if (isa<mlir::FloatType>(type))
|
||||||
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
|
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
|
||||||
else if (auto complexTy = type.dyn_cast<mlir::ComplexType>())
|
else if (auto complexTy = dyn_cast<mlir::ComplexType>(type))
|
||||||
return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth()))
|
return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth()))
|
||||||
.str();
|
.str();
|
||||||
else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
|
else if (auto memRefType = dyn_cast<UnrankedMemRefType>(type))
|
||||||
return "mr" + getTypeToken(memRefType.getElementType());
|
return "mr" + getTypeToken(memRefType.getElementType());
|
||||||
|
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
|
@ -171,7 +171,7 @@ static LogicalResult mungeFunction(
|
||||||
for (auto en : llvm::enumerate(types)) {
|
for (auto en : llvm::enumerate(types)) {
|
||||||
Type retType = en.value();
|
Type retType = en.value();
|
||||||
Value retVal = op.getOperand(en.index());
|
Value retVal = op.getOperand(en.index());
|
||||||
if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
|
if (auto memrefReturnType = dyn_cast<MemRefType>(retType)) {
|
||||||
auto elemType = memrefReturnType.getElementType();
|
auto elemType = memrefReturnType.getElementType();
|
||||||
retType = UnrankedMemRefType::get(elemType, 0);
|
retType = UnrankedMemRefType::get(elemType, 0);
|
||||||
// Cast to unranked memref type before sending it as a function
|
// Cast to unranked memref type before sending it as a function
|
||||||
|
|
Loading…
Reference in New Issue