Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)

Like #3130, this change gradually replaces the deprecated cast member functions with their free-function equivalents, per the MLIR deprecation guide:

https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
pull/3244/head
penguin_wwy 2024-04-28 05:00:56 +08:00 committed by GitHub
parent 466618e45e
commit 6679728c56
56 changed files with 936 additions and 983 deletions
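
The rewrite is mechanical: every deprecated member-function cast on a Type or Attribute becomes the corresponding free function. A minimal sketch of the pattern follows; the helper and variable names are illustrative and do not appear in this commit.

// Sketch only: `castStyleSketch` and `value` are hypothetical names,
// not code from this commit.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;

static void castStyleSketch(Value value) {
  // Deprecated member-function style (the `-` side of each hunk below):
  //   auto vt = value.getType().cast<torch::Torch::ValueTensorType>();
  //   auto bt = value.getType().dyn_cast<torch::Torch::BaseTensorType>();
  //   bool fp = value.getType().isa<FloatType>();

  // Free-function style (the `+` side); dyn_cast_or_null follows the
  // same shape.
  auto vt = cast<torch::Torch::ValueTensorType>(value.getType());
  auto bt = dyn_cast<torch::Torch::BaseTensorType>(value.getType());
  bool fp = isa<FloatType>(value.getType());
  (void)vt;
  (void)bt;
  (void)fp;
}

As before, `cast` asserts on a type mismatch, `dyn_cast` returns null, and `isa` only tests, so each hunk preserves the original failure behavior.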

View File

@@ -23,7 +23,7 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
                                             int64_t dimA, int64_t dimB,
                                             Value &transposed) {
   Type transposedType;
-  if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
+  if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
                                dimA, dimB, transposedType)))
     return failure();
   Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
@@ -554,7 +554,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // conversions which are not supported in Torch-MLIR right now.
         Torch::ValueTensorType targetTy =
-            target.getType().cast<Torch::ValueTensorType>();
+            cast<Torch::ValueTensorType>(target.getType());
         if (!targetTy.hasDtype()) {
           return rewriter.notifyMatchFailure(binder.op,
                                              "target tensor must have a dtype");
@@ -753,9 +753,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
         Type listElemType =
-            tensors[0]
-                .getType()
-                .cast<Torch::BaseTensorType>()
+            cast<Torch::BaseTensorType>(tensors[0].getType())
                 .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                       /*optionalDtype=*/nullptr);
         Type listType = Torch::ListType::get(listElemType);
@@ -869,7 +867,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
-        auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
+        auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
         if (!weightTensorType || !weightTensorType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected weight type having sizes");
@@ -1188,7 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
-        auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
+        auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
         if (!weightTensorType || !weightTensorType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected weight type having sizes");
@@ -1427,7 +1425,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.customOpNameStringAttr(mode, "mode", "DCR") ||
             binder.tensorResultType(resultType))
           return failure();
-        auto inputTy = input.getType().dyn_cast<Torch::BaseTensorType>();
+        auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
        if (!inputTy || !inputTy.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
@@ -1536,9 +1534,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        Value scale = operands[1];
        Value zeropoint = operands[2];
-       auto operandTy = operand.getType().cast<Torch::ValueTensorType>();
-       auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
+       auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+       auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
        if (!scaleTy || !scaleTy.hasSizes())
          return rewriter.notifyMatchFailure(binder.op, "requires known rank");
        if (!resultType.hasDtype())
@@ -1611,7 +1609,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
        Value trainVal = operands[2];
        auto trainTensorType =
-           trainVal.getType().dyn_cast<Torch::BaseTensorType>();
+           dyn_cast<Torch::BaseTensorType>(trainVal.getType());
        if (!trainTensorType)
          return rewriter.notifyMatchFailure(binder.op,
                                             "train tensor must have a type");
@@ -1629,8 +1627,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        if (auto valueTensorLiteralOp =
                trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
-         auto val = valueTensorLiteralOp.getValue()
-                        .cast<DenseElementsAttr>()
+         auto val = cast<DenseElementsAttr>(valueTensorLiteralOp.getValue())
                         .getSplatValue<bool>();
          trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
        } else {
@@ -2072,7 +2069,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
        SmallVector<Value> dimList;
        Torch::BaseTensorType shapeType =
-           shape.getType().cast<Torch::BaseTensorType>();
+           cast<Torch::BaseTensorType>(shape.getType());
        Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
            ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(

View File

@@ -104,10 +104,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
          return rewriter.notifyMatchFailure(
              binder.op, "operand grid_sampler bind failure");
-       auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
+       auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        uint32_t inputRank = inputShape.size();
-       auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
+       auto gridTensorType = cast<Torch::ValueTensorType>(grid.getType());
        ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
        uint32_t gridRank = gridShape.size();
@@ -233,7 +233,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
          axis = rank + axis;
        }
        // need input type and sizes to flatten/unflatten later.
-       auto inputTy = input.getType().cast<Torch::ValueTensorType>();
+       auto inputTy = cast<Torch::ValueTensorType>(input.getType());
        if (!inputTy || !inputTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "failed to get input type or sizes");
@@ -1065,7 +1065,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
        auto transpose = [&](Value m) -> Value {
-         auto tty = m.getType().cast<Torch::ValueTensorType>();
+         auto tty = cast<Torch::ValueTensorType>(m.getType());
          auto shape = tty.getOptionalSizes();
          if (shape.has_value()) {
            llvm::SmallVector<int64_t> newShape(shape.value());
@@ -1134,7 +1134,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.tensorResultType(resultType))
          return failure();
-       auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
+       auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
@@ -1228,7 +1228,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        rank = *maybeRank;
        SmallVector<Value> normalized;
        axis = Torch::toPositiveDim(axis, rank);
-       auto xType = x.getType().cast<Torch::ValueTensorType>();
+       auto xType = cast<Torch::ValueTensorType>(x.getType());
        if (!xType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input (X) to have sizes");
@@ -1307,7 +1307,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        // Get pads shape and rank. The pads tensor is expected to be 1-D
        // tensor.
-       auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
+       auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
        if (!padsTensorType || !padsTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty pad tensor");
@@ -1323,7 +1323,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        // As per onnx.Pad documentation, padSize = 2*num_data_axes
        // (if axes param not passed). Need to be updated when adding
        // support for `axes` param.
-       auto dataOpTy = data.getType().cast<Torch::ValueTensorType>();
+       auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
        TensorType dataTensor = dataOpTy.toBuiltinTensor();
        if (!dataTensor || !dataTensor.hasRank())
          return rewriter.notifyMatchFailure(
@@ -1350,7 +1350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        }
        if (!constantValue) {
-         auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
+         auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
          if (dataTensorType.getDtype().isa<IntegerType>())
            constantValue = rewriter.create<Torch::ConstantIntOp>(
                loc, rewriter.getI64IntegerAttr(0));

View File

@@ -54,7 +54,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
   SmallVector<Value> axesList;
   Value axesVal;
   if (!binder.tensorOperandAtIndex(axesVal, 1)) {
-    auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+    auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
     if (!inputType.hasSizes() || !resultType.hasSizes()) {
       return rewriter.notifyMatchFailure(
           binder.op, "unimplemented: expected input and result to have shapes");
@@ -97,7 +97,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
   }
   if (axesList.empty()) {
     Torch::BaseTensorType axesType =
-        axesVal.getType().cast<Torch::BaseTensorType>();
+        cast<Torch::BaseTensorType>(axesVal.getType());
     auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
     auto axesShape = axesTy.getSizes();
     if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
@@ -177,7 +177,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        Value scale = operands[1];
        Value zeropoint = operands[2];
-       auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
+       auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
        if (!scaleTy || !scaleTy.hasSizes())
          return rewriter.notifyMatchFailure(binder.op, "requires known rank");
        if (!resultType.hasDtype())
@@ -241,7 +241,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        Value c = operands.size() == 9 ? operands[8] : nullptr;
        auto check = [](Value v) {
-         auto vTy = v.getType().cast<Torch::ValueTensorType>();
+         auto vTy = cast<Torch::ValueTensorType>(v.getType());
          return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
        };
        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
@@ -250,7 +250,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, "not supported for non per-tensor quantization");
        auto extract = [&rewriter, &binder](Value v) {
-         auto vTy = v.getType().cast<Torch::ValueTensorType>();
+         auto vTy = cast<Torch::ValueTensorType>(v.getType());
          Type extractTy = rewriter.getType<Torch::FloatType>();
          if (isa<IntegerType>(vTy.getDtype()))
            extractTy = rewriter.getType<Torch::IntType>();
@@ -268,7 +268,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        auto make = [&rewriter, &binder](Value v, Value scale,
                                         Value zp) -> Value {
-         auto ty = v.getType().cast<Torch::ValueTensorType>();
+         auto ty = cast<Torch::ValueTensorType>(v.getType());
          auto newTy = getQTorchTypeFromTorchIntType(ty);
          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
              binder.getLoc(), newTy, v, scale, zp);
@@ -351,7 +351,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        Value cZp = operands[7];
        auto check = [](Value v) {
-         auto vTy = v.getType().cast<Torch::ValueTensorType>();
+         auto vTy = cast<Torch::ValueTensorType>(v.getType());
          for (auto dim : vTy.getSizes())
            if (dim != 1)
              return false;
@@ -368,7 +368,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                rewriter.getType<Torch::IntType>()),
            ValueRange{});
        auto extract = [&rewriter, &binder, &emptyList](Value v) {
-         auto vTy = v.getType().cast<Torch::ValueTensorType>();
+         auto vTy = cast<Torch::ValueTensorType>(v.getType());
          if (!vTy.getSizes().empty()) {
            vTy = rewriter.getType<Torch::ValueTensorType>(
                ArrayRef<int64_t>({}), vTy.getOptionalDtype());
@@ -393,7 +393,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        auto make = [&rewriter, &binder](Value v, Value scale,
                                         Value zp) -> Value {
-         auto ty = v.getType().cast<Torch::ValueTensorType>();
+         auto ty = cast<Torch::ValueTensorType>(v.getType());
          auto newTy = getQTorchTypeFromTorchIntType(ty);
          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
              binder.getLoc(), newTy, v, scale, zp);
@@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
          return failure();
        Value data = inputOperands[0];
-       auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+       auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
        if (!inputType.hasSizes() || !resultType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op,
@@ -718,7 +718,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        if (dimList.empty()) {
          Value axes = inputOperands[1];
          Torch::BaseTensorType axesType =
-             axes.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(axes.getType());
          SmallVector<int64_t> selectSizes{1};
          Type selectResultType = axesType.getWithSizesAndDtype(
              selectSizes, axesType.getOptionalDtype());
@@ -760,7 +760,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        if (binder.tensorOperands(data, axes) ||
            binder.tensorResultType(resultType))
          return failure();
-       auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+       auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
        if (!inputType.hasSizes() || !resultType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op,
@@ -925,8 +925,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        // Perform an AtenToDtype op on the squared sum of the operand, stored
        // now in operand itself.
-       auto size = operand.getType()
-                       .dyn_cast<Torch::ValueTensorType>()
-                       .getOptionalSizes();
+       auto size = dyn_cast<Torch::ValueTensorType>(operand.getType())
+                       .getOptionalSizes();
        auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
            size, rewriter.getF32Type());
@@ -1005,7 +1004,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        Value axesVal;
        if (!binder.tensorOperandAtIndex(axesVal, 1)) {
-         auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+         auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
          if (!inputType.hasSizes() || !resultType.hasSizes()) {
            return rewriter.notifyMatchFailure(
                binder.op,
@@ -1053,7 +1052,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        if (axesList.empty()) {
          Torch::BaseTensorType axesType =
-             axesVal.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(axesVal.getType());
          auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
          auto axesShape = axesTy.getSizes();
          if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
@@ -1191,7 +1190,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        // Extract the axes values from the axes operand:
        if (!binder.tensorOperandAtIndex(axes, 1)) {
          Torch::BaseTensorType axesType =
-             axes.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(axes.getType());
          SmallVector<int64_t> selectSizes{1};
          Type selectResultType = axesType.getWithSizesAndDtype(
              selectSizes, axesType.getOptionalDtype());
@@ -1344,7 +1343,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        // Extract the axes values from the axes operand:
        if (!binder.tensorOperandAtIndex(axes, 1)) {
          Torch::BaseTensorType axesType =
-             axes.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(axes.getType());
          SmallVector<int64_t> selectSizes{1};
          Type selectResultType = axesType.getWithSizesAndDtype(
              selectSizes, axesType.getOptionalDtype());
@@ -1467,12 +1466,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        auto loc = binder.getLoc();
        auto result0Ty =
-           binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
-       auto resultNTy = binder.op->getResults()
-                            .back()
-                            .getType()
-                            .cast<Torch::ValueTensorType>();
-       auto selfTy = self.getType().cast<Torch::ValueTensorType>();
+           cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
+       auto resultNTy = cast<Torch::ValueTensorType>(
+           binder.op->getResults().back().getType());
+       auto selfTy = cast<Torch::ValueTensorType>(self.getType());
        int64_t dim = axis;
        if (dim < 0)
@@ -1555,7 +1552,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                               binder.op, "Failed to get num_outputs attribute");
        auto result0Ty =
-           binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
+           cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
        auto selfTy =
            cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
@@ -1617,7 +1614,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();
-       auto operandType = operand.getType().cast<Torch::ValueTensorType>();
+       auto operandType = cast<Torch::ValueTensorType>(operand.getType());
        TensorType tensorType = operandType.toBuiltinTensor();
        if (!tensorType || !tensorType.hasRank())
          return failure();
@@ -1705,26 +1702,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        }
        auto context = rewriter.getContext();
-       auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
+       auto operandTorchTy = cast<Torch::ValueTensorType>(operand.getType());
        auto operandTy =
-           operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+           dyn_cast<RankedTensorType>(operandTorchTy.toBuiltinTensor());
        if (!operandTy)
          return rewriter.notifyMatchFailure(
              binder.op,
              "Expected tensor operator argument to be a ranked tensor type");
-       auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
+       auto startsTorchTy = cast<Torch::ValueTensorType>(starts.getType());
        auto startsTy =
-           startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+           dyn_cast<RankedTensorType>(startsTorchTy.toBuiltinTensor());
        int startSize = startsTy.getDimSize(0);
-       auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
-       auto endsTy =
-           endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+       auto endsTorchTy = cast<Torch::ValueTensorType>(ends.getType());
+       auto endsTy = dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
        int endSize = endsTy.getDimSize(0);
        auto resultTy =
-           resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
+           dyn_cast<RankedTensorType>(resultTorchType.toBuiltinTensor());
        if (!resultTy)
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type to be a ranked tensor type");
@@ -1768,9 +1764,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                           "and their dimensions to match");
        if (axes) {
-         auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
+         auto axesTorchTy = cast<Torch::ValueTensorType>(axes.getType());
          auto axesTy =
-             axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+             dyn_cast<RankedTensorType>(axesTorchTy.toBuiltinTensor());
          int64_t numAxes = axesTy.getDimSize(0);
          if (!(axesTy && numAxes == endSize))
@@ -1792,7 +1788,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        auto select = [&](Value v, Value k) -> Value {
-         auto ty = v.getType().cast<Torch::ValueTensorType>();
+         auto ty = cast<Torch::ValueTensorType>(v.getType());
          auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
              loc,
              Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
@@ -1872,7 +1868,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        }
        Torch::BaseTensorType shapeType =
-           shape.getType().cast<Torch::BaseTensorType>();
+           cast<Torch::BaseTensorType>(shape.getType());
        SmallVector<Value> dimList;
        SmallVector<int64_t> selectSizes;
        selectSizes.push_back(1);
@@ -2007,7 +2003,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        // instead of using the dynamic axes at operand[1].
        if (!binder.tensorOperandAtIndex(axes, 1)) {
          Torch::BaseTensorType axesType =
-             axes.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(axes.getType());
          auto sizes = axesType.getSizes();
          for (int i = 0; i < sizes[0]; i++) {
            Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
@@ -2136,7 +2132,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        // int32, int64 Assuming start, limit and delta to be same type (could
        // they be different?)
        Torch::BaseTensorType startTensorType =
-           start.getType().cast<Torch::BaseTensorType>();
+           cast<Torch::BaseTensorType>(start.getType());
        bool isFloatDType = startTensorType.getDtype().isF64() ||
                            startTensorType.getDtype().isF32();
        bool isIntDType = startTensorType.getDtype().isInteger(16) ||
@@ -2222,7 +2218,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        SmallVector<int64_t> selectSizes;
        selectSizes.push_back(1);
        Torch::BaseTensorType shapeType =
-           repeatDims.getType().cast<Torch::BaseTensorType>();
+           cast<Torch::BaseTensorType>(repeatDims.getType());
        Type selectResultType = shapeType.getWithSizesAndDtype(
            llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(

View File

@@ -95,7 +95,7 @@ public:
     Value input = adaptor.getA();
     Type resultType =
         this->getTypeConverter()->convertType(op->getResult(0).getType());
-    if (!input.getType().isa<mlir::FloatType>())
+    if (!isa<mlir::FloatType>(input.getType()))
      input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
     Value result = rewriter.create<UnaryOp>(loc, input);
     rewriter.replaceOp(op,
@@ -172,8 +172,8 @@ public:
   matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = op->getContext();
-    if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
-      if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
+    if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
+      if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
        Type elemTy = op.getValueAttr().getElementType();
        unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
        Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
@@ -187,9 +187,9 @@ public:
      }
    }
     if (auto elements =
-            op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
-      if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
-        if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
+            dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
+      if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
+        if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
          Type builtinTensorElemTy =
              IntegerType::get(context, intType.getIntOrFloatBitWidth());
          auto shapedType =

View File

@@ -49,8 +49,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            SmallVector<Value> &strides) {
   Location loc = op.getLoc();
   auto input = adaptor.getSelf();
-  RankedTensorType inputType =
-      input.getType().template cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -73,8 +72,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value builtinTypeStart = adaptor.getStart();
   Value builtinTypeEnd = adaptor.getEnd();
-  if (torchTypeStart.getType().isa<OptionalType>() ||
-      torchTypeEnd.getType().isa<OptionalType>())
+  if (isa<OptionalType>(torchTypeStart.getType()) ||
+      isa<OptionalType>(torchTypeEnd.getType()))
     return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
   Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
@@ -84,7 +83,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   // We cannot use to positive valid dim as for negative strides we need to
   // clamp to `-1` so that the full tensor bounds are available:
   Value end = builtinTypeEnd;
-  if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
+  if (isa<Torch::NoneType>(torchTypeEnd.getType())) {
    end = dimSize;
  } else {
    end = castIntToIndex(rewriter, loc, end);
@@ -594,7 +593,7 @@ public:
     int64_t endDim;
     if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim)))
       return rewriter.notifyMatchFailure(op, "end_dim must be constant");
-    auto type = adaptor.getSelf().getType().cast<RankedTensorType>();
+    auto type = cast<RankedTensorType>(adaptor.getSelf().getType());
    auto inputRank = type.getRank();
    if (inputRank == 1) {
      // If input rank is equal to 1, then there's no scope for flattening the
@@ -604,7 +603,7 @@ public:
     }
     auto resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
     if (startDim < 0)
       startDim += inputRank;
     if (endDim < 0)
@@ -652,7 +651,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value self = op.getSelf();
-    BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
+    BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
     if (!outputTensorType.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "unimplemented: output must have known sizes");
@@ -660,7 +659,7 @@ public:
     std::optional<unsigned> maybeRank = getTensorRank(self);
     if (!maybeRank)
       return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
-    auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
+    auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
     if (!inputTensorType || !inputTensorType.hasSizes()) {
       return rewriter.notifyMatchFailure(op,
                                          "Expected input type having sizes");
@@ -901,7 +900,7 @@ public:
   getInputAndOutputShape(Value inputTorchTensor,
                          SmallVector<Value> outputSizeTorchInt) {
     SmallVector<int64_t> inputShape(
-        inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
+        cast<BaseTensorType>(inputTorchTensor.getType()).getSizes());
     SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
     for (auto [outputDim, outputDimSize] :
          llvm::enumerate(outputSizeTorchInt)) {
@@ -945,11 +944,11 @@ public:
       return failure();
     Location loc = op.getLoc();
     Value input = adaptor.getSelf();
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     int64_t inputRank = inputType.getRank();
     const TypeConverter *typeConverter = getTypeConverter();
     auto resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     int64_t resultRank = resultType.getRank();
     if (resultRank == 0) {
       rewriter
@@ -1349,7 +1348,7 @@ public:
     auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
     auto resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
                                                    outputDims);
     return success();
@@ -1367,13 +1366,13 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Value input = adaptor.getSelf();
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     auto inputShape = inputType.getShape();
     int64_t inputRank = inputType.getRank();
     const TypeConverter *typeConverter = getTypeConverter();
     auto resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     auto resultShape = resultType.getShape();
     int64_t resultRank = resultType.getRank();
@@ -1437,7 +1436,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Value input = adaptor.getSelf();
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     int64_t inputRank = inputType.getRank();
     if (inputRank == 0) {
@@ -1460,7 +1459,7 @@ public:
     const TypeConverter *typeConverter = getTypeConverter();
     auto resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     int64_t resultRank = resultType.getRank();
     // If the dim(th) dimension of operand tensor type is not statically unit,
@@ -1510,7 +1509,7 @@ public:
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
     auto inputRank =
-        adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+        cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
     dim = toPositiveDim(dim, inputRank + 1);
     if (!isValidDim(dim, inputRank + 1))
       return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@@ -1535,9 +1534,8 @@ public:
         }
       }
     }
-    auto resultType = getTypeConverter()
-                          ->convertType(op->getResult(0).getType())
-                          .cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         op, resultType, adaptor.getSelf(), reassociationMap);
     return success();
@@ -1564,11 +1562,10 @@ public:
       return rewriter.notifyMatchFailure(op, "dim1 must be constant");
     auto inVector = adaptor.getSelf();
-    auto inType = inVector.getType().cast<RankedTensorType>();
+    auto inType = cast<RankedTensorType>(inVector.getType());
     auto inputRank = inType.getRank();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
+    auto outType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     auto elementType = inType.getElementType();
     dim0 = toPositiveDim(dim0, inputRank);
@@ -1634,11 +1631,10 @@ public:
       return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
     Value inVector = adaptor.getSelf();
-    auto inType = inVector.getType().cast<RankedTensorType>();
+    auto inType = cast<RankedTensorType>(inVector.getType());
     int64_t inputRank = inType.getRank();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
+    auto outType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     Type elementType = inType.getElementType();
     // Check if the dimensions are a valid constants.
@@ -1747,7 +1743,7 @@ public:
         getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
     RankedTensorType newResultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     int rank = newResultType.getRank();
     Value dimValue = op.getDim();
     int64_t dim;
@@ -1802,7 +1798,7 @@ public:
     // which in this case is `inShapeConverted` because this shape will yield
     // us the dimension size of the output.
     SmallVector<bool> useBroadcastToShape;
-    int64_t inputRank = self.getType().cast<RankedTensorType>().getRank();
+    int64_t inputRank = cast<RankedTensorType>(self.getType()).getRank();
     for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e;
          ++i) {
       int64_t dim;
@@ -1821,7 +1817,7 @@ public:
     SmallVector<Value> inShapeConverted = getTypeConvertedValues(
         rewriter, op.getLoc(), getTypeConverter(), inShape);
     auto newResultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
     Value result;
     if (failed(torch_to_linalg::broadcastToGivenShape(
             op, rewriter, self, inShapeConverted, newResultType, result,
@@ -1869,7 +1865,7 @@ public:
     Location loc = op.getLoc();
     Value self = adaptor.getSelf();
     Value src = adaptor.getSrc();
-    RankedTensorType selfType = self.getType().cast<RankedTensorType>();
+    RankedTensorType selfType = cast<RankedTensorType>(self.getType());
     // The non_blocking should be a constant `False`.
     bool nonBlocking;
@@ -1954,7 +1950,7 @@ public:
     }
     Value src = adaptor.getSrc();
-    auto srcType = src.getType().cast<RankedTensorType>();
+    auto srcType = cast<RankedTensorType>(src.getType());
     int64_t srcRank = srcType.getRank();
     SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
     // TODO: audit possibility of sparsity on these tensor
@@ -1992,7 +1988,7 @@ public:
     auto input = adaptor.getSelf();
     RankedTensorType resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     auto elementType = resultType.getElementType();
     SmallVector<Value> resultShape;
@@ -2070,9 +2066,9 @@ public:
     auto input = adaptor.getSelf();
     RankedTensorType resultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
-    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
+    RankedTensorType inputType = cast<RankedTensorType>(input.getType());
     auto inputElementType = getElementTypeOrSelf(input.getType());
     if (!isa<ComplexType>(inputElementType)) {
       return op.emitError("only ComplexType is allowed as input type");
@@ -2157,7 +2153,7 @@ public:
       return rewriter.notifyMatchFailure(op, "dim2 must be constant");
     Value inputMatrix = adaptor.getSelf();
-    RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
+    RankedTensorType inputType = cast<RankedTensorType>(inputMatrix.getType());
     int64_t inputRank = inputType.getRank();
     if (inputRank < 2)
@@ -2277,7 +2273,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
   static SmallVector<Value>
   getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
                           int64_t offset, int64_t dim1, int64_t dim2) {
-    auto inputType = tensor.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(tensor.getType());
     auto inputRank = inputType.getRank();
     // output tensor always has 1 extra dimension
@@ -2314,7 +2310,7 @@ public:
     Location loc = op->getLoc();
     Value input = adaptor.getSelf();
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     auto inputRank = inputType.getRank();
     auto resultRank = inputRank + 1;

View File

@ -80,7 +80,7 @@ public:
if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant"); return op.emitError("unimplemented: dim is not constant");
int64_t inputRank = int64_t inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
dim = toPositiveDim(dim, inputRank); dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@ -88,7 +88,7 @@ public:
Value indices = adaptor.getIndex(); Value indices = adaptor.getIndex();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
RankedTensorType newResultTy = RankedTensorType newResultTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
int64_t rank = newResultTy.getRank(); int64_t rank = newResultTy.getRank();
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices); SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
@ -128,9 +128,9 @@ public:
Value weight = adaptor.getWeight(); Value weight = adaptor.getWeight();
Value indices = adaptor.getIndices(); Value indices = adaptor.getIndices();
RankedTensorType newResultType = RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
if (weightTy.getRank() != 2) if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2"); return rewriter.notifyMatchFailure(op, "weight must be rank 2");
Value embeddingDim = getDimOp(rewriter, loc, weight, 1); Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
@ -140,7 +140,7 @@ public:
sizes.push_back(embeddingDim); sizes.push_back(embeddingDim);
int64_t resultRank = sizes.size(); int64_t resultRank = sizes.size();
auto indicesTy = indices.getType().cast<RankedTensorType>(); auto indicesTy = cast<RankedTensorType>(indices.getType());
int64_t indicesRank = indicesTy.getRank(); int64_t indicesRank = indicesTy.getRank();
SmallVector<AffineExpr> indicesExprs; SmallVector<AffineExpr> indicesExprs;
for (int i = 0; i < indicesRank; i++) for (int i = 0; i < indicesRank; i++)
@ -274,15 +274,15 @@ public:
"include_last_offset is expected to be a constant boolean value."); "include_last_offset is expected to be a constant boolean value.");
} }
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
if (weightTy.getRank() != 2) if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2"); return rewriter.notifyMatchFailure(op, "weight must be rank 2");
auto indicesTy = indices.getType().cast<RankedTensorType>(); auto indicesTy = cast<RankedTensorType>(indices.getType());
if (indicesTy.getRank() != 1) if (indicesTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "indices must be a vector"); return rewriter.notifyMatchFailure(op, "indices must be a vector");
auto offsetsTy = offsets.getType().cast<RankedTensorType>(); auto offsetsTy = cast<RankedTensorType>(offsets.getType());
if (offsetsTy.getRank() != 1) if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "offsets much be a vector"); return rewriter.notifyMatchFailure(op, "offsets much be a vector");
@ -471,10 +471,9 @@ public:
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Value indices = adaptor.getIndex(); Value indices = adaptor.getIndex();
auto indicesTy = cast<RankedTensorType>(indices.getType()); auto indicesTy = cast<RankedTensorType>(indices.getType());
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
unsigned inputRank = inputType.getRank(); unsigned inputRank = inputType.getRank();
@ -604,10 +603,9 @@ public:
op, "aten.index.Tensor: index tensor must not be None"); op, "aten.index.Tensor: index tensor must not be None");
} }
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
int inputRank = inputType.getRank(); int inputRank = inputType.getRank();
int resultRank = resultType.getRank(); int resultRank = resultType.getRank();
@ -625,7 +623,7 @@ public:
int maxRank = -1; int maxRank = -1;
for (auto indexTensor : indexTensors) { for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType = RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>(); cast<RankedTensorType>(indexTensor.getType());
maxRank = std::max(maxRank, (int)indexTensorType.getRank()); maxRank = std::max(maxRank, (int)indexTensorType.getRank());
} }
@ -639,7 +637,7 @@ public:
int64_t staticDimSize = -1; int64_t staticDimSize = -1;
for (auto indexTensor : indexTensors) { for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType = RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>(); cast<RankedTensorType>(indexTensor.getType());
int64_t indexTensorRank = indexTensorType.getRank(); int64_t indexTensorRank = indexTensorType.getRank();
if ((maxRank - indexTensorRank) > (i - startIndex)) if ((maxRank - indexTensorRank) > (i - startIndex))
continue; continue;
@ -714,7 +712,7 @@ public:
for (auto indexTensor : indexTensors) { for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType = RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>(); cast<RankedTensorType>(indexTensor.getType());
auto indexTensorShape = auto indexTensorShape =
makeShapeTorchCompatible(indexTensorType.getShape()); makeShapeTorchCompatible(indexTensorType.getShape());
int rank = indexTensorShape.size(); int rank = indexTensorShape.size();
@ -828,7 +826,7 @@ public:
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Type resultType = getTypeConverter()->convertType(op.getResult().getType()); Type resultType = getTypeConverter()->convertType(op.getResult().getType());
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
Type elementType = inputType.getElementType(); Type elementType = inputType.getElementType();
@ -989,7 +987,7 @@ public:
Value gradOutput = adaptor.getGradOutput(); Value gradOutput = adaptor.getGradOutput();
Type resultType = getTypeConverter()->convertType(op.getResult().getType()); Type resultType = getTypeConverter()->convertType(op.getResult().getType());
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>(); auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
auto gradOutputRank = gradOutputType.getRank(); auto gradOutputRank = gradOutputType.getRank();
Type elementType = gradOutputType.getElementType(); Type elementType = gradOutputType.getElementType();
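
All of the rewrites in this file follow one mechanical rule: the deprecated member-function casts on Type move to the equivalent free functions. A minimal sketch of the before/after, assuming an MLIR build; rankOf and input are illustrative names, not code from this patch:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/Value.h"
  #include "mlir/Support/LLVM.h" // re-exports llvm::cast/dyn_cast/isa into mlir::

  using namespace mlir;

  static int64_t rankOf(Value input) {
    // Deprecated member form, kept here only as a comment:
    //   auto ty = input.getType().cast<RankedTensorType>();
    // Free-function form with the same assert-on-mismatch semantics:
    auto ty = cast<RankedTensorType>(input.getType());
    return ty.getRank();
  }

The semantics are unchanged: cast still asserts if the type is not a RankedTensorType, so the invariants these lowerings rely on are preserved verbatim.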

View File

@ -48,7 +48,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits); minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric( arg = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{arg}, rewriter, loc, ValueRange{arg},
arg.getType().cast<TensorType>().getElementType(), cast<TensorType>(arg.getType()).getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) { [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = Value result =
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue); rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
@ -58,7 +58,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms, static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto valueTy = value.getType().cast<RankedTensorType>(); auto valueTy = cast<RankedTensorType>(value.getType());
auto inShape = valueTy.getShape(); auto inShape = valueTy.getShape();
llvm::SmallVector<int64_t> outShape; llvm::SmallVector<int64_t> outShape;
llvm::SmallVector<Value> dynDims; llvm::SmallVector<Value> dynDims;
@ -100,8 +100,8 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>(); RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>(); RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) { if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -109,9 +109,9 @@ public:
} }
ValueTensorType lhsTorchType = ValueTensorType lhsTorchType =
op.getSelf().getType().cast<ValueTensorType>(); cast<ValueTensorType>(op.getSelf().getType());
ValueTensorType rhsTorchType = ValueTensorType rhsTorchType =
op.getMat2().getType().cast<ValueTensorType>(); cast<ValueTensorType>(op.getMat2().getType());
Value lhsZeroPoint, rhsZeroPoint; Value lhsZeroPoint, rhsZeroPoint;
getZeroPoint(op.getSelf(), lhsZeroPoint); getZeroPoint(op.getSelf(), lhsZeroPoint);
@ -148,7 +148,7 @@ public:
"mismatching contracting dimension for torch.aten.mm")); "mismatching contracting dimension for torch.aten.mm"));
} }
auto resultTy = op.getType().cast<ValueTensorType>(); auto resultTy = cast<ValueTensorType>(op.getType());
auto resultDTy = resultTy.toBuiltinTensor().getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = cast<TensorType>(newResultType).getElementType(); Type elementType = cast<TensorType>(newResultType).getElementType();
@ -176,9 +176,9 @@ public:
// change uint8 quantization -> int8 quantization // change uint8 quantization -> int8 quantization
int64_t numBits = int64_t numBits =
lhsType.getElementType().cast<mlir::IntegerType>().getWidth(); cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth(); numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
matmul = matmul =
@ -229,9 +229,9 @@ public:
MLIRContext *context = op.getContext(); MLIRContext *context = op.getContext();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfRank = auto selfRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType = Type elementType =
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 = Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1)); rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
@ -299,8 +299,8 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
return failure(); return failure();
} }
auto lhsType = lhs.getType().cast<RankedTensorType>(); auto lhsType = cast<RankedTensorType>(lhs.getType());
auto rhsType = rhs.getType().cast<RankedTensorType>(); auto rhsType = cast<RankedTensorType>(rhs.getType());
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType()); auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType()); auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());
@ -348,9 +348,9 @@ public:
// change uint8 quantization -> int8 quantization // change uint8 quantization -> int8 quantization
int64_t numBits = int64_t numBits =
lhsType.getElementType().cast<mlir::IntegerType>().getWidth(); cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth(); numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
// for quantized vec-vec, vec-mat, and mat-vec cases, lower to // for quantized vec-vec, vec-mat, and mat-vec cases, lower to
@ -726,8 +726,8 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
Value rhs = adaptor.getMat2(); Value rhs = adaptor.getMat2();
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>(); RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>(); RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type resultElementType = Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType(); cast<RankedTensorType>(newResultType).getElementType();
@ -794,7 +794,7 @@ public:
Value input = adaptor.getInput(); /* in form of N*C*H*W */ Value input = adaptor.getInput(); /* in form of N*C*H*W */
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
Value bias = adaptor.getBias(); Value bias = adaptor.getBias();
auto resultTy = op.getType().cast<ValueTensorType>(); auto resultTy = cast<ValueTensorType>(op.getType());
Value inputZp, weightZp; Value inputZp, weightZp;
if (auto make = op.getInput() if (auto make = op.getInput()
@ -826,7 +826,7 @@ public:
} }
if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) { if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType(); auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
if (!biasDTy.isInteger(32)) { if (!biasDTy.isInteger(32)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "quantized result ty should be i32 accumulator"); op, "quantized result ty should be i32 accumulator");
@ -838,15 +838,15 @@ public:
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: only constant transposed supported"); op, "unimplemented: only constant transposed supported");
auto inputDTy = input.getType().cast<RankedTensorType>().getElementType(); auto inputDTy = cast<RankedTensorType>(input.getType()).getElementType();
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType(); auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
auto resultDTy = resultTy.toBuiltinTensor().getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType();
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) || if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) || !isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy)) !isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
return op.emitError("unimplemented: non-fp not-int type"); return op.emitError("unimplemented: non-fp not-int type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank(); size_t inRank = cast<RankedTensorType>(input.getType()).getRank();
size_t numSpatialDims = inRank - 2; size_t numSpatialDims = inRank - 2;
if (numSpatialDims < 1 || numSpatialDims > 3) if (numSpatialDims < 1 || numSpatialDims > 3)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1067,11 +1067,11 @@ public:
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0); rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
} else { } else {
auto biasType = bias.getType().cast<RankedTensorType>(); auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1) if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank(); auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
SmallVector<AffineMap> indexingMaps = { SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output // bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
@ -1228,9 +1228,9 @@ public:
// Special depthwise case // Special depthwise case
auto inShape = makeShapeTorchCompatible( auto inShape = makeShapeTorchCompatible(
input.getType().cast<RankedTensorType>().getShape()); cast<RankedTensorType>(input.getType()).getShape());
auto weightShape = makeShapeTorchCompatible( auto weightShape = makeShapeTorchCompatible(
weight.getType().cast<RankedTensorType>().getShape()); cast<RankedTensorType>(weight.getType()).getShape());
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
// Collapse weight shape // Collapse weight shape
@ -1264,7 +1264,7 @@ public:
// Grouped case, use the grouped conv linalg op // Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) { auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = tensor.getType().cast<RankedTensorType>(); auto inType = cast<RankedTensorType>(tensor.getType());
auto inShape = makeShapeTorchCompatible(inType.getShape()); auto inShape = makeShapeTorchCompatible(inType.getShape());
SmallVector<int64_t> outShape; SmallVector<int64_t> outShape;
@ -1297,7 +1297,7 @@ public:
// expand F,C,H,W -> G,F/G,C,H,W // expand F,C,H,W -> G,F/G,C,H,W
auto expandWeight = [&](Value tensor) { auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>(); auto inType = cast<RankedTensorType>(tensor.getType());
auto inShape = makeShapeTorchCompatible(inType.getShape()); auto inShape = makeShapeTorchCompatible(inType.getShape());
SmallVector<int64_t> outShape{ SmallVector<int64_t> outShape{
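
One readability consequence of the migration, visible throughout the quantized matmul and convolution hunks above: a left-to-right member chain becomes an inside-out expression. A sketch under the same assumptions as before (elementBitWidth is a hypothetical helper):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LLVM.h"

  using namespace mlir;

  static unsigned elementBitWidth(RankedTensorType tensorTy) {
    // Old: tensorTy.getElementType().cast<mlir::IntegerType>().getWidth();
    return cast<IntegerType>(tensorTy.getElementType()).getWidth();
  }

Only the cast link of the chain is rewrapped; accessors before and after it stay where they were.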

View File

@ -80,7 +80,7 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
SmallVectorImpl<int64_t> &dilationInts, SmallVectorImpl<int64_t> &dilationInts,
SmallVectorImpl<Value> &kernelSizeIntValues, SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &outTensorShape, Value initValue) { SmallVectorImpl<Value> &outTensorShape, Value initValue) {
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
Location loc = op->getLoc(); Location loc = op->getLoc();
Value N = getDimOp(rewriter, loc, self, 0); Value N = getDimOp(rewriter, loc, self, 0);
@ -116,7 +116,7 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0}; SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
SmallVector<int64_t> highPaddingIncludingNC = {0, 0}; SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank(); unsigned selfRank = cast<RankedTensorType>(self.getType()).getRank();
unsigned paddingIntsSize = paddingInts.size(); unsigned paddingIntsSize = paddingInts.size();
if (paddingIntsSize == 2 * (selfRank - 2)) { if (paddingIntsSize == 2 * (selfRank - 2)) {
@ -153,7 +153,7 @@ static LogicalResult createPoolingOp(
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr, SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) { SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc(); Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput) if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type"); return op->emitError("unimplemented: non-floating point type");
@ -214,7 +214,7 @@ private:
bool ceilMode) const { bool ceilMode) const {
SmallVector<Value, 5> outTensorShape; SmallVector<Value, 5> outTensorShape;
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(), APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
@ -307,7 +307,7 @@ public:
const TypeConverter *typeConverter = this->getTypeConverter(); const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank(); int64_t selfRank = cast<RankedTensorType>(self.getType()).getRank();
if (selfRank != Dim + 2) if (selfRank != Dim + 2)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -326,7 +326,7 @@ public:
strideInts, paddingInts))) strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
if constexpr (Dim == 2) { if constexpr (Dim == 2) {
SmallVector<Value, 4> outTensorShape; SmallVector<Value, 4> outTensorShape;
@ -389,7 +389,7 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
RankedTensorType selfType = self.getType().cast<RankedTensorType>(); RankedTensorType selfType = cast<RankedTensorType>(self.getType());
Type elementType = selfType.getElementType(); Type elementType = selfType.getElementType();
RankedTensorType indicesRankedTensorType = RankedTensorType indicesRankedTensorType =
getTypeConverter() getTypeConverter()
@ -552,7 +552,7 @@ public:
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Type inputElementType = Type inputElementType =
self.getType().cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(self.getType()).getElementType();
Type resultType = typeConverter->convertType(op.getType()); Type resultType = typeConverter->convertType(op.getType());
Type resultElementType = Type resultElementType =
cast<RankedTensorType>(resultType).getElementType(); cast<RankedTensorType>(resultType).getElementType();
@ -592,8 +592,7 @@ public:
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) { if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
Value kHtimeskW = rewriter.create<arith::MulIOp>( Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
divisor = op.getDivisorOverride().getType().template isa<Torch::NoneType>() divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
? kHtimeskW ? kHtimeskW
: adaptor.getDivisorOverride(); : adaptor.getDivisorOverride();
} else { } else {
@ -901,7 +900,7 @@ public:
const TypeConverter *typeConverter = this->getTypeConverter(); const TypeConverter *typeConverter = this->getTypeConverter();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
const Type elementType = inputType.getElementType(); const Type elementType = inputType.getElementType();
// get rank of input (same as rank of output) // get rank of input (same as rank of output)
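
The pooling hunks also show a small side benefit: inside a template, the member form needs the template disambiguator (.template isa<...>()), while the free function does not. A hypothetical reduction of the AtenAvgPool change above, assuming the torch-mlir headers; the op name and accessor are taken from the surrounding code:

  #include "mlir/Support/LLVM.h"
  #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

  using namespace mlir;

  template <typename OpTy>
  static bool divisorIsNone(OpTy op) {
    // Old: op.getDivisorOverride().getType().template isa<Torch::NoneType>()
    return isa<torch::Torch::NoneType>(op.getDivisorOverride().getType());
  }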

View File

@ -127,7 +127,7 @@ public:
Value from = adaptor.getFrom(); Value from = adaptor.getFrom();
Value to = adaptor.getTo(); Value to = adaptor.getTo();
Value generator = adaptor.getGenerator(); Value generator = adaptor.getGenerator();
RankedTensorType resultType = self.getType().cast<RankedTensorType>(); RankedTensorType resultType = cast<RankedTensorType>(self.getType());
Type elemTy = resultType.getElementType(); Type elemTy = resultType.getElementType();
Type f64Ty = rewriter.getF64Type(); Type f64Ty = rewriter.getF64Type();
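
dyn_cast migrates the same way and keeps its try-cast contract. A minimal sketch (widthOrZero is a hypothetical helper) of the guard style these lowerings use:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LLVM.h"

  using namespace mlir;

  static unsigned widthOrZero(Type ty) {
    // dyn_cast returns a null IntegerType on mismatch, where cast would assert.
    if (auto intTy = dyn_cast<IntegerType>(ty))
      return intTy.getWidth();
    return 0;
  }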

View File

@ -66,8 +66,7 @@ public:
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType())); cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
auto idxResultType = auto idxResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType())); cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
RankedTensorType inputType = input.getType().template cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
Type idxElementType = Type idxElementType =
getElementTypeOrSelf(typec->convertType(idxResultType)); getElementTypeOrSelf(typec->convertType(idxResultType));
if (!isa<IntegerType>(idxElementType)) if (!isa<IntegerType>(idxElementType))
@ -472,7 +471,7 @@ private:
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
typename T::Adaptor adaptor(operands); typename T::Adaptor adaptor(operands);
opInfo.tensorOperand = adaptor.getSelf(); opInfo.tensorOperand = adaptor.getSelf();
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim))) if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -480,8 +479,7 @@ private:
SmallVector<int64_t> dimList; SmallVector<int64_t> dimList;
int64_t dim; int64_t dim;
bool isNoneOrEmptyDimList = op.getDim().getType().template isa<Torch::NoneType>(); bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
// Fix negative dimensions, if any, before adding to the list. // Fix negative dimensions, if any, before adding to the list.
for (int64_t dim : dimList) { for (int64_t dim : dimList) {
@ -522,7 +520,7 @@ private:
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
AtenNormScalarOp>(op)) { AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0]; opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
// `AtenMinOp` each reduce along all the dimensions of the input tensor. // `AtenMinOp` each reduce along all the dimensions of the input tensor.
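
The multi-type guards left untouched above (isa<AtenAnyOp, AtenAllOp, ...>) already use the free spelling: isa accepts several candidates at once and succeeds if any matches. A minimal sketch over builtin types; the predicate name is invented for illustration:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LLVM.h"

  using namespace mlir;

  // Mirrors the dtype guards used throughout this patch.
  static bool isFloatOrInt(Type ty) {
    return isa<FloatType, IntegerType>(ty);
  }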

View File

@ -42,7 +42,7 @@ public:
return failure(); return failure();
Location loc = op->getLoc(); Location loc = op->getLoc();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto type = self.getType().cast<RankedTensorType>(); auto type = cast<RankedTensorType>(self.getType());
int64_t rank = type.getRank(); int64_t rank = type.getRank();
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>(); auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
@ -105,7 +105,7 @@ public:
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
Type padType = tensor::PadOp::inferResultType( Type padType = tensor::PadOp::inferResultType(
self.getType().cast<RankedTensorType>(), staticLow, staticHigh); cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
Value paddedInput = rewriter.create<tensor::PadOp>( Value paddedInput = rewriter.create<tensor::PadOp>(
loc, padType, self, lowPad, highPad, castedValue); loc, padType, self, lowPad, highPad, castedValue);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
@ -354,7 +354,7 @@ public:
// The pin_memory should be either `False` or `none`. // The pin_memory should be either `False` or `none`.
bool pinMemory; bool pinMemory;
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) { pinMemory)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -376,7 +376,7 @@ public:
auto resultType = typeConverter->convertType(op.getType()) auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>(); .template cast<RankedTensorType>();
Type resultElementType; Type resultElementType;
if (op.getDtype().getType().template isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = resultType.getElementType(); resultElementType = resultType.getElementType();
} else { } else {
int64_t dtypeInt; int64_t dtypeInt;
@ -423,7 +423,7 @@ public:
// The pin_memory should be either `False` or `none`. // The pin_memory should be either `False` or `none`.
bool pinMemory; bool pinMemory;
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) pinMemory))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -480,7 +480,7 @@ public:
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = auto resultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType; Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) { if (op.getDtype().getType().isa<Torch::NoneType>()) {
resultElementType = getDefaultDtypeForTorchScalar( resultElementType = getDefaultDtypeForTorchScalar(
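
The single most common shape in this patch wraps a converted result type: typeConverter->convertType(t).cast<RankedTensorType>() becomes cast<RankedTensorType>(typeConverter->convertType(t)). A sketch of that pattern in isolation (the helper name is invented for illustration):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LLVM.h"
  #include "mlir/Transforms/DialectConversion.h"

  using namespace mlir;

  static RankedTensorType convertedTensorType(const TypeConverter &converter,
                                              Type t) {
    return cast<RankedTensorType>(converter.convertType(t));
  }

Note that context lines such as the remaining .template cast<RankedTensorType>() and .isa<Torch::NoneType>() above are deliberately untouched in this hunk.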

View File

@ -38,7 +38,7 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Value dim = adaptor.getDim(); Value dim = adaptor.getDim();
auto type = self.getType().cast<RankedTensorType>(); auto type = cast<RankedTensorType>(self.getType());
Value inputRank = rewriter.create<arith::ConstantOp>( Value inputRank = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(type.getRank())); loc, rewriter.getI64IntegerAttr(type.getRank()));
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
@ -86,8 +86,7 @@ public:
Value input = adaptor.getA(); Value input = adaptor.getA();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input); SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputSizes.size(); int64_t inputRank = inputSizes.size();
Type inputDtype = op.getA().getType().template cast<BaseTensorType>().getDtype(); Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
// The `input` tensor must contain exactly one element, i.e., either the // The `input` tensor must contain exactly one element, i.e., either the
// `input` is a zero rank tensor or all the dimensions of the `input` tensor // `input` is a zero rank tensor or all the dimensions of the `input` tensor
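
cast (free or member) is an assertion, not a check; it fits here because the declared operand type of the op already guarantees a Torch tensor. A hypothetical sketch, assuming the torch-mlir headers:

  #include "mlir/IR/Value.h"
  #include "mlir/Support/LLVM.h"
  #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

  using namespace mlir;
  using namespace mlir::torch::Torch;

  static Type dtypeOf(Value v) {
    // Asserts that v is a Torch tensor; prefer dyn_cast plus a match
    // failure when the operand type is not already verified.
    return cast<BaseTensorType>(v.getType()).getDtype();
  }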

View File

@ -34,7 +34,7 @@ using namespace mlir::torch::Torch;
// Check if a ranked-tensor has the specified element type. // Check if a ranked-tensor has the specified element type.
template <typename elementType> static bool hasElementType(Value tensor) { template <typename elementType> static bool hasElementType(Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>(); auto tensorType = cast<RankedTensorType>(tensor.getType());
Type tensorElementType = tensorType.getElementType(); Type tensorElementType = tensorType.getElementType();
return isa<elementType>(tensorElementType); return isa<elementType>(tensorElementType);
} }
@ -173,8 +173,7 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
return nullptr; return nullptr;
} }
Type elementalType = op.getSelf().getType().template cast<BaseTensorType>().getDtype(); Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) { if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs); return createLessThan(b, loc, elementalType, lhs, rhs);
} }
@ -200,7 +199,7 @@ template <arith::CmpIPredicate predicate>
static LogicalResult static LogicalResult
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs, createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
Operation *op, ArrayRef<Value> operands, Value &result) { Operation *op, ArrayRef<Value> operands, Value &result) {
auto inputType = operands[0].getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(operands[0].getType());
uint64_t inputRank = inputType.getRank(); uint64_t inputRank = inputType.getRank();
// Use the indices of the two innermost dimensions. // Use the indices of the two innermost dimensions.
@ -405,7 +404,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr; return nullptr;
} }
Type resultElementType = Type resultElementType =
bitwiseAndScalar.getType().cast<BaseTensorType>().getDtype(); cast<BaseTensorType>(bitwiseAndScalar.getType()).getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType); /*dstOriginalDtype=*/resultElementType);
@ -537,7 +536,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (auto relu = dyn_cast<AtenReluOp>(op)) { if (auto relu = dyn_cast<AtenReluOp>(op)) {
Value zeroPoint = getZeroPoint(relu.getSelf()); Value zeroPoint = getZeroPoint(relu.getSelf());
Value arg = payloadArgs[0]; Value arg = payloadArgs[0];
auto intType = arg.getType().dyn_cast<mlir::IntegerType>(); auto intType = dyn_cast<mlir::IntegerType>(arg.getType());
if (zeroPoint && !intType) { if (zeroPoint && !intType) {
relu.emitError("unimplemented: non-integer quantized Relu."); relu.emitError("unimplemented: non-integer quantized Relu.");
return nullptr; return nullptr;
@ -739,9 +738,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto add = dyn_cast<AtenAddTensorOp>(op)) { if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands); AtenAddTensorOp::Adaptor adaptor(operands);
Type resultElementType = add.getType().cast<BaseTensorType>().getDtype(); Type resultElementType = cast<BaseTensorType>(add.getType()).getDtype();
Type dtype = converter->convertType(add.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(add.getType())).getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
@ -762,10 +760,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) { if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
AtenSubTensorOp::Adaptor adaptor(operands); AtenSubTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(sub.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(sub.getType())).getElementType();
Type resultElementType = sub.getType().cast<BaseTensorType>().getDtype(); Type resultElementType = cast<BaseTensorType>(sub.getType()).getDtype();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType); /*dstOriginalDtype=*/resultElementType);
@ -785,8 +782,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
} }
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) { if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
Type dtype = converter->convertType(subScalar.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(subScalar.getType())).getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype);
@ -805,11 +802,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr; return nullptr;
} }
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) { if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
Type dtype = converter->convertType(addScalar.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(addScalar.getType())).getElementType();
Type resultElementType = Type resultElementType =
addScalar.getType().cast<BaseTensorType>().getDtype(); cast<BaseTensorType>(addScalar.getType()).getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt, /*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType); /*dstOriginalDtype=*/resultElementType);
@ -832,8 +829,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) { if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
AtenMulTensorOp::Adaptor adaptor(operands); AtenMulTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(mul.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(mul.getType())).getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
@ -846,8 +842,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
} }
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) { if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
Type dtype = converter->convertType(atan2.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(atan2.getType())).getElementType();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
atan2.emitError("Atan2 requires floating point result type"); atan2.emitError("Atan2 requires floating point result type");
@ -883,8 +878,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto div = dyn_cast<AtenDivTensorOp>(op)) { if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands); AtenDivTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(div.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(div.getType())).getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
@ -907,7 +901,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
operands); operands);
} }
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) { if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype(); Type dtype = cast<ValueTensorType>(pow.getType()).getDtype();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
@ -925,14 +919,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
Type dtype = pow.getSelf().getType().cast<ValueTensorType>().getDtype(); Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted); return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
} }
if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) { if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) {
Type dtype = converter->convertType(pow.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(pow.getType())).getElementType();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
@ -944,8 +937,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto imag = dyn_cast<AtenImagOp>(op)) { if (auto imag = dyn_cast<AtenImagOp>(op)) {
Type dtype = converter->convertType(imag.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(imag.getType())).getElementType();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
imag.emitError("unimplemented: non-floating point dtype"); imag.emitError("unimplemented: non-floating point dtype");
@ -956,8 +948,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto real = dyn_cast<AtenRealOp>(op)) { if (auto real = dyn_cast<AtenRealOp>(op)) {
Type dtype = converter->convertType(real.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(real.getType())).getElementType();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
real.emitError("unimplemented: non-floating point dtype"); real.emitError("unimplemented: non-floating point dtype");
@ -968,7 +959,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) { if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
Type dtype = gtScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from // TODO: `gtTensor` and `gtScalar` share similar code and can be called from
// one static function. // one static function.
@ -998,7 +989,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) { if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
Type dtype = geScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
// can be refactored. // can be refactored.
@ -1028,7 +1019,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) { if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
Type dtype = eqScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
@ -1044,7 +1035,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) { if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
Type dtype = neScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
@ -1060,7 +1051,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) { if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
@ -1088,7 +1079,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) { if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
Type dtype = leScalar.getSelf().getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
@ -1116,8 +1107,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) { if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
Type dtype = converter->convertType(whereSelf.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(whereSelf.getType())).getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
@ -1141,7 +1132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::AddFOp>(loc, start, weightedDelta); return b.create<arith::AddFOp>(loc, start, weightedDelta);
} }
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) { if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
Type dtype = minimum.getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
Type elemTy = converter->convertType(minimum.getType()) Type elemTy = converter->convertType(minimum.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
@ -1151,7 +1142,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::SelectOp>(loc, pred, lhs, rhs); return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
} }
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) { if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
Type dtype = maximum.getType().cast<BaseTensorType>().getDtype(); Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
Type elemTy = converter->convertType(maximum.getType()) Type elemTy = converter->convertType(maximum.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
@ -1170,15 +1161,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr; return nullptr;
} }
Type dtype = converter->convertType(clamp.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(clamp.getType())).getElementType();
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) { if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
clamp.emitError("unimplement type for clamp"); clamp.emitError("unimplement type for clamp");
return nullptr; return nullptr;
} }
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype(); Type dstOriginalDtype = cast<BaseTensorType>(clamp.getType()).getDtype();
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype); bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) { if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
isUnsigned = intTy.isUnsigned(); isUnsigned = intTy.isUnsigned();
@ -1219,8 +1209,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
clampTensor.emitError("unimplemented: runtime optional type"); clampTensor.emitError("unimplemented: runtime optional type");
return nullptr; return nullptr;
} }
Type dtype = converter->convertType(clampTensor.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(clampTensor.getType())).getElementType();
bool isMinNone = true; bool isMinNone = true;
auto result = payloadArgs[0]; auto result = payloadArgs[0];
@ -1263,8 +1253,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result; return result;
} }
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) { if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
Type dtype = converter->convertType(rsub.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(rsub.getType())).getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype);
@ -1283,8 +1272,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr; return nullptr;
} }
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) { if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
Type dtype = converter->convertType(mulScalar.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(mulScalar.getType())).getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
@ -1297,8 +1286,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) { if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
Value input = payloadArgs[0]; Value input = payloadArgs[0];
Type dtype = converter->convertType(atenToDtype.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(atenToDtype.getType())).getElementType();
Type resultElementType; Type resultElementType;
int64_t dtypeInt; int64_t dtypeInt;
@ -1320,8 +1309,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result; return result;
} }
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) { if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
Type dtype = converter->convertType(divScalar.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(divScalar.getType())).getElementType();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
divScalar.emitError("unimplemented: non-floating point dtype"); divScalar.emitError("unimplemented: non-floating point dtype");
@ -1395,8 +1384,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result; return result;
} }
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) { if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
Type dtype = converter->convertType(reciprocal.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(reciprocal.getType())).getElementType();
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Type elementType = arg.getType(); Type elementType = arg.getType();
@ -1416,8 +1405,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// The approach used here is as follows: // The approach used here is as follows:
// result = self <= threshold ? value : self // result = self <= threshold ? value : self
AtenThresholdOp::Adaptor adaptor(operands); AtenThresholdOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(thresholdOp.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(thresholdOp.getType())).getElementType();
Value self = payloadArgs[0]; Value self = payloadArgs[0];
@ -1438,8 +1427,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// The approach used here is as follows: // The approach used here is as follows:
// result = self <= threshold ? 0 : grad // result = self <= threshold ? 0 : grad
AtenThresholdBackwardOp::Adaptor adaptor(operands); AtenThresholdBackwardOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(thresholdBackward.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(thresholdBackward.getType())).getElementType();
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
@ -1459,15 +1448,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) { if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) {
AtenFillScalarOp::Adaptor adaptor(operands); AtenFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(fillScalar.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(fillScalar.getType())).getElementType();
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
} }
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) { if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands); AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillTensor.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(maskedFillTensor.getType())).getElementType();
Value input = payloadArgs[0]; Value input = payloadArgs[0];
@ -1477,8 +1466,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) { if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
AtenFillTensorOp::Adaptor adaptor(operands); AtenFillTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(fillTensor.getType()).cast<RankedTensorType>().getElementType(); Type dtype = cast<RankedTensorType>(converter->convertType(fillTensor.getType())).getElementType();
return convertScalarToDtype(b, loc, payloadArgs[1], dtype); return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
} }
@ -1519,7 +1508,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto value = payloadArgs[0]; auto value = payloadArgs[0];
auto valueTy = value.getType(); auto valueTy = value.getType();
auto qtensor = op->getOperand(0); auto qtensor = op->getOperand(0);
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype(); auto qtensorTy = cast<ValueTensorType>(qtensor.getType()).getDtype();
Value zp, scale; Value zp, scale;
if (auto makeQTensor = if (auto makeQTensor =
@ -1744,8 +1733,8 @@ public:
Value ignoreIndex = adaptor.getIgnoreIndex(); Value ignoreIndex = adaptor.getIgnoreIndex();
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex); Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank(); unsigned inputRank = cast<RankedTensorType>(input.getType()).getRank();
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank(); unsigned targetRank = cast<RankedTensorType>(target.getType()).getRank();
// TODO: Add support for k-dim loss. // TODO: Add support for k-dim loss.
if (inputRank > 2) { if (inputRank > 2) {
@ -1931,11 +1920,11 @@ public:
failed(checkNotNone(rewriter, op, runningVar))) failed(checkNotNone(rewriter, op, runningVar)))
return failure(); return failure();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto weightType = weight.getType().cast<RankedTensorType>(); auto weightType = cast<RankedTensorType>(weight.getType());
auto biasType = bias.getType().cast<RankedTensorType>(); auto biasType = cast<RankedTensorType>(bias.getType());
auto runningMeanType = runningMean.getType().cast<RankedTensorType>(); auto runningMeanType = cast<RankedTensorType>(runningMean.getType());
auto runningVarType = runningVar.getType().cast<RankedTensorType>(); auto runningVarType = cast<RankedTensorType>(runningVar.getType());
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
if (inputRank < 2) if (inputRank < 2)
@ -2032,9 +2021,9 @@ public:
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
Value totalWeight = adaptor.getTotalWeight(); Value totalWeight = adaptor.getTotalWeight();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
int inputRank = inputType.getRank(); int inputRank = inputType.getRank();
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>(); auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
Type resultElementType = gradOutputType.getElementType(); Type resultElementType = gradOutputType.getElementType();
int64_t reduction; int64_t reduction;
@ -2059,7 +2048,7 @@ public:
createZeroInitTensor(rewriter, loc, outputSize, resultElementType); createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
auto getAffineMapForSingleElementTensor = [&](Value tensor) { auto getAffineMapForSingleElementTensor = [&](Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>(); auto tensorType = cast<RankedTensorType>(tensor.getType());
SmallVector<AffineExpr> affineExprs(tensorType.getRank(), SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
rewriter.getAffineConstantExpr(0)); rewriter.getAffineConstantExpr(0));
return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
@ -2188,12 +2177,12 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>(); auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
auto resultRankedTensorType = auto resultRankedTensorType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
// The dimension being split must be statically known. // The dimension being split must be statically known.
@ -2233,11 +2222,11 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>(); auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
auto resultRankedTensorType = auto resultRankedTensorType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
// Collapse range must be statically known. // Collapse range must be statically known.
int64_t startInt; int64_t startInt;
@ -2328,7 +2317,7 @@ public:
return failure(); return failure();
} }
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputElementType = inputType.getElementType(); auto inputElementType = inputType.getElementType();
if (!isa<mlir::FloatType>(inputElementType)) { if (!isa<mlir::FloatType>(inputElementType)) {
@ -2433,8 +2422,8 @@ public:
return failure(); return failure();
} }
auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype(); auto operandDTy = cast<ValueTensorType>(operand.getType()).getDtype();
auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype(); auto zeropointDTy = cast<ValueTensorType>(zeropoint.getType()).getDtype();
operand = converter->materializeTargetConversion( operand = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(operand.getType()), operand); rewriter, loc, converter->convertType(operand.getType()), operand);
scale = converter->materializeTargetConversion( scale = converter->materializeTargetConversion(
@ -2537,7 +2526,7 @@ public:
Value twoFloat = rewriter.create<arith::ConstantOp>( Value twoFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 2.0)); loc, rewriter.getFloatAttr(floatType, 2.0));
Value input = adaptor.getInput(); Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape(); auto inputShape = inputType.getShape();
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2); Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3); Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
@ -2558,7 +2547,7 @@ public:
Value innerDim1e = Value innerDim1e =
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat); rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
Value grid = adaptor.getGrid(); Value grid = adaptor.getGrid();
auto gridType = grid.getType().cast<RankedTensorType>(); auto gridType = cast<RankedTensorType>(grid.getType());
auto gridShape = gridType.getShape(); auto gridShape = gridType.getShape();
auto gridRank = gridType.getRank(); auto gridRank = gridType.getRank();
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex); SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
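Every hunk above makes the same mechanical move: the deprecated member casts on mlir::Type become the LLVM-style free functions. A minimal sketch of the two spellings, using a hypothetical rankOf helper that is not part of this patch:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper: rank of a value known to hold a ranked tensor.
int64_t rankOf(Value v) {
  // Deprecated member form, being removed upstream:
  //   auto ty = v.getType().cast<RankedTensorType>();
  // Free-function form used throughout this patch:
  auto ty = cast<RankedTensorType>(v.getType());
  return ty.getRank();
}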
@ -37,9 +37,8 @@ Value torch_to_linalg::getPaddedTensor(
SmallVectorImpl<int64_t> &lowPaddingInts, SmallVectorImpl<int64_t> &lowPaddingInts,
SmallVectorImpl<int64_t> &highPaddingInts, Value pad) { SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
Location loc = op->getLoc(); Location loc = op->getLoc();
Type rankedTensorType = tensor::PadOp::inferResultType(input.getType().cast<RankedTensorType>(), lowPaddingInts, highPaddingInts);
Type rankedTensorType = tensor::PadOp::inferResultType(cast<RankedTensorType>(input.getType()), lowPaddingInts, highPaddingInts);
SmallVector<OpFoldResult> lowPaddings = SmallVector<OpFoldResult> lowPaddings =
getIndexIntsAsOpFoldResult(b, lowPaddingInts); getIndexIntsAsOpFoldResult(b, lowPaddingInts);
SmallVector<OpFoldResult> highPaddings = SmallVector<OpFoldResult> highPaddings =
@ -61,7 +60,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
Location loc = op->getLoc(); Location loc = op->getLoc();
Value c0 = b.create<arith::ConstantOp>( Value c0 = b.create<arith::ConstantOp>(
loc, loc,
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType())); b.getZeroAttr(cast<RankedTensorType>(input.getType()).getElementType()));
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0); return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
} }
@ -73,7 +72,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
int unpaddedDims, Value pad) { int unpaddedDims, Value pad) {
assert(input.getType().isa<RankedTensorType>() && assert(input.getType().isa<RankedTensorType>() &&
"input must be RankedTensorType"); "input must be RankedTensorType");
unsigned int inRank = input.getType().cast<RankedTensorType>().getRank(); unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
Location loc = op->getLoc(); Location loc = op->getLoc();
SmallVector<Value> inputDims = getTensorSizes(b, loc, input); SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
@ -86,7 +85,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
pad < paddingIncludingUnchanged.end(); pad++) pad < paddingIncludingUnchanged.end(); pad++)
*pad = castIntToIndex(b, loc, *pad); *pad = castIntToIndex(b, loc, *pad);
Type elementType = input.getType().cast<RankedTensorType>().getElementType(); Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
// TODO: audit possibility of sparsity on this tensor // TODO: audit possibility of sparsity on this tensor
Type inputType = Type inputType =
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>( RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
@ -158,7 +157,7 @@ Value torch_to_linalg::getOutputDimForConvTransposeOps(
Value torch_to_linalg::createReductionLinalgGeneric( Value torch_to_linalg::createReductionLinalgGeneric(
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
// Get the result shape by obtaining the size of each // Get the result shape by obtaining the size of each
// dimension in the input tensor that is not getting reduced. // dimension in the input tensor that is not getting reduced.
@ -237,7 +236,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
SmallVector<int64_t> operandRanks; SmallVector<int64_t> operandRanks;
operandRanks.resize(tensorOperands.size()); operandRanks.resize(tensorOperands.size());
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
return tensor.getType().dyn_cast<RankedTensorType>().getRank(); return dyn_cast<RankedTensorType>(tensor.getType()).getRank();
}); });
auto resultRankIt = auto resultRankIt =
@ -253,7 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b); bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
for (Value tensorOperand : tensorOperands) { for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs; SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>(); auto type = cast<RankedTensorType>(tensorOperand.getType());
for (auto size : for (auto size :
llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) { llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
// If the size is statically known to be 1, we don't want any // If the size is statically known to be 1, we don't want any
@ -327,7 +326,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
Operation *op, PatternRewriter &rewriter, Value input, Operation *op, PatternRewriter &rewriter, Value input,
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType, SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
Value &result, SmallVector<bool> useBroadcastToShape) { Value &result, SmallVector<bool> useBroadcastToShape) {
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank(); int64_t inputRank = inputType.getRank();
int64_t outputRank = broadcastToShape.size(); int64_t outputRank = broadcastToShape.size();
ArrayRef<int64_t> outputShape = broadcastType.getShape(); ArrayRef<int64_t> outputShape = broadcastType.getShape();
@ -525,7 +524,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
Value tensor) { Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>(); auto tensorType = cast<RankedTensorType>(tensor.getType());
auto rank = tensorType.getRank(); auto rank = tensorType.getRank();
SmallVector<int64_t> unknownSizes(rank, kUnknownSize); SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
return b.create<tensor::CastOp>( return b.create<tensor::CastOp>(
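The free dyn_cast keeps the member form's semantics and returns a null value on mismatch, so results still need a guard before use; note that chaining it directly, as in the createElementwiseLinalgGeneric hunk above, is only safe when the operand is already known to be ranked. A hedged sketch with a hypothetical helper:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Returns the rank when `ty` is a ranked tensor, -1 otherwise.
int64_t rankOrUnranked(Type ty) {
  if (auto ranked = dyn_cast<RankedTensorType>(ty))
    return ranked.getRank();
  return -1;
}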
@ -66,8 +66,8 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other, mlir::Value &self, mlir::Value &other,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
auto selfTy = self.getType().template dyn_cast<RankedTensorType>(); auto selfTy = dyn_cast<RankedTensorType>(self.getType());
auto otherTy = other.getType().template dyn_cast<RankedTensorType>(); auto otherTy = dyn_cast<RankedTensorType>(other.getType());
auto selfRank = selfTy.getRank(); auto selfRank = selfTy.getRank();
auto otherRank = otherTy.getRank(); auto otherRank = otherTy.getRank();
if (selfRank == 0 || otherRank == 0) if (selfRank == 0 || otherRank == 0)
@ -171,7 +171,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfType = self.getType().cast<TensorType>(); auto selfType = cast<TensorType>(self.getType());
if (!selfType) { if (!selfType) {
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
} }
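A subtlety the migration preserves: cast<T>(x) asserts on mismatch, so the null check after the cast in the hunk above can only trip for dyn_cast. A sketch of how the two free functions divide the work (checkSelf is illustrative, not from this file):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// cast<T>: the type is guaranteed; asserts on mismatch.
// dyn_cast<T>: the type must be checked; returns null on mismatch.
LogicalResult checkSelf(Value self) {
  auto selfTy = dyn_cast<TensorType>(self.getType());
  if (!selfTy) // meaningful only after dyn_cast; cast would have asserted
    return failure();
  return success();
}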
@ -197,12 +197,12 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<StablehloOpT>( rewriter.replaceOpWithNewOp<StablehloOpT>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@ -229,14 +229,14 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
if (resultTy.getElementType().template isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(resultTy.getElementType())) {
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src); rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
return success(); return success();
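Hunks in templated patterns also shed the `template` disambiguator that the member spelling required in dependent code; the free function needs none. A sketch under that assumption, with an illustrative helper:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

// Was: converter->convertType(op.getType()).template cast<mlir::TensorType>();
template <typename AtenOpT>
mlir::TensorType convertedResultType(AtenOpT op,
                                     const mlir::TypeConverter *converter) {
  return mlir::cast<mlir::TensorType>(converter->convertType(op.getType()));
}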
@ -304,8 +304,7 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto inputType = adaptor.getA().getType().template dyn_cast<RankedTensorType>();
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
if (!inputType) if (!inputType)
op.emitError("only Tensor types supported in StableHLO"); op.emitError("only Tensor types supported in StableHLO");
@ -313,8 +312,7 @@ public:
Value input = adaptor.getA(); Value input = adaptor.getA();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input); SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputSizes.size(); int64_t inputRank = inputSizes.size();
Type inputDtype = op.getA().getType().template cast<BaseTensorType>().getDtype();
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
Value constantOne = Value constantOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
@ -345,9 +343,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<TensorType>(); auto lhsTy = cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsTy = rhs.getType().cast<TensorType>(); auto rhsTy = cast<TensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError("only Tensor types supported"); return op.emitError("only Tensor types supported");
@ -378,9 +376,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsType = dyn_cast<RankedTensorType>(rhs.getType());
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
@ -433,9 +431,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>(); auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
TensorType rhsType = rhs.getType().dyn_cast<TensorType>(); TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
@ -527,8 +525,8 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
if (!lhsTy) if (!lhsTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
@ -616,8 +614,8 @@ public:
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
if (!lhsTy) if (!lhsTy)
return op.emitError("lhs must be a ranked tensor type"); return op.emitError("lhs must be a ranked tensor type");
@ -659,11 +657,10 @@ public:
return rewriter.notifyMatchFailure(op, "dim1 must be constant"); return rewriter.notifyMatchFailure(op, "dim1 must be constant");
} }
auto inType = self.getType().cast<RankedTensorType>(); auto inType = cast<RankedTensorType>(self.getType());
auto inputRank = inType.getRank(); auto inputRank = inType.getRank();
auto outType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
auto outType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
dim0 = toPositiveDim(dim0, inputRank); dim0 = toPositiveDim(dim0, inputRank);
if (!isValidDim(dim0, inputRank)) { if (!isValidDim(dim0, inputRank)) {
@ -691,7 +688,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self); rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
return success(); return success();
} }
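The result-type idiom in this hunk is worth spelling out: during dialect conversion the op still carries its source-dialect type (a !torch.vtensor here), so the builtin tensor type has to come from the pattern's TypeConverter, and the cast then recovers the concrete RankedTensorType. A hedged sketch outside any particular pattern:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative only: convert a result type, then cast to the builtin form.
RankedTensorType convertedType(Operation *op, const TypeConverter &converter) {
  return cast<RankedTensorType>(
      converter.convertType(op->getResult(0).getType()));
}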
@ -701,7 +698,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
AtenSizeIntOp op, OpAdaptor adaptor, AtenSizeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
// Not a tensor type. // Not a tensor type.
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>(); auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
if (!selfType) if (!selfType)
return op.emitError("only tensor types are currently supported"); return op.emitError("only tensor types are currently supported");
@ -739,7 +736,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
Value other = adaptor.getOther(); Value other = adaptor.getOther();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
// promote self and other types // promote self and other types
self = hlo::promoteType(rewriter, op.getLoc(), self, outType); self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
other = hlo::promoteType(rewriter, op.getLoc(), other, outType); other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
@ -764,10 +761,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
AtenBroadcastToOp op, OpAdaptor adaptor, AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
auto outType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
auto outType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
if (options.enableStaticShape && selfTy.hasStaticShape()) { if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
@ -831,10 +827,9 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
// Not a ranked tensor type // Not a ranked tensor type
auto inType = self.getType().dyn_cast<RankedTensorType>(); auto inType = dyn_cast<RankedTensorType>(self.getType());
auto outType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
auto outType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
if (!inType) if (!inType)
return op.emitError("only ranked tensor types with static shapes are " return op.emitError("only ranked tensor types with static shapes are "
"currently supported"); "currently supported");
@ -861,15 +856,14 @@ template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite( LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor, ValueTensorLiteralOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
RankedTensorType resultType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
// Tensors with integer types need to be converted to signless integer // Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse // element type. All tensors with element types other than integer can reuse
// existing elements attribute. // existing elements attribute.
// TODO: what about unsigned integer? // TODO: what about unsigned integer?
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) { if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
Type builtinTensorElemTy = resultType.getElementType(); Type builtinTensorElemTy = resultType.getElementType();
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
@ -892,9 +886,8 @@ template <>
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
AtenTensorIntOp op, OpAdaptor adaptor, AtenTensorIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
RankedTensorType resultType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
Type outElementType = resultType.getElementType(); Type outElementType = resultType.getElementType();
Value innerValue = adaptor.getT(); Value innerValue = adaptor.getT();
Value stablehloTensor = Value stablehloTensor =
@ -910,10 +903,10 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
AtenReciprocalOp op, OpAdaptor adaptor, AtenReciprocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto outTy = auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!inputTy.getElementType().isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inputTy.getElementType())) {
return op.emitError("only floating-point datatype legalization supported " return op.emitError("only floating-point datatype legalization supported "
"for AtenReciprocalOp"); "for AtenReciprocalOp");
} }
@ -929,9 +922,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor, AtenPowTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>(); auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent(); Value rhs = adaptor.getExponent();
TensorType rhsType = rhs.getType().dyn_cast<TensorType>(); TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
@ -1002,9 +995,8 @@ template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite( LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
PrimNumToTensorScalarOp op, OpAdaptor adaptor, PrimNumToTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
RankedTensorType outputType = getTypeConverter()->convertType(op->getResult(0).getType()).cast<RankedTensorType>();
RankedTensorType outputType = cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
auto outputElemType = outputType.getElementType(); auto outputElemType = outputType.getElementType();
Value stablehloTensor = hlo::scalarToStablehloTensor( Value stablehloTensor = hlo::scalarToStablehloTensor(
rewriter, op, adaptor.getA(), outputElemType); rewriter, op, adaptor.getA(), outputElemType);
@ -1018,8 +1010,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
AtenScalarImplicitOp op, OpAdaptor adaptor, AtenScalarImplicitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc(); Location loc = op.getLoc();
Type inputDtype = op.getA().getType().template cast<BaseTensorType>().getDtype();
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
Type resultType = Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType()); this->getTypeConverter()->convertType(op->getResult(0).getType());
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA()); auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
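Because cast<T>(x) asserts instead of returning null, its result can be used inline, which is what lets the wrapped two-line member chain above collapse onto one line. A sketch with a hypothetical torch-dialect helper:

#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// Dtype of a value whose type is known to be a torch tensor type.
mlir::Type dtypeOf(mlir::Value v) {
  return mlir::cast<mlir::torch::Torch::BaseTensorType>(v.getType()).getDtype();
}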
@ -1037,7 +1028,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
// Not a tensor type. // Not a tensor type.
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>(); auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
if (!selfType) if (!selfType)
return op.emitError("only tensor types are currently supported"); return op.emitError("only tensor types are currently supported");
@ -1055,7 +1046,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
AtenReluOp op, OpAdaptor adaptor, AtenReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
auto lhsElemTy = lhsTy.getElementType(); auto lhsElemTy = lhsTy.getElementType();
if (!isa<mlir::FloatType>(lhsElemTy)) { if (!isa<mlir::FloatType>(lhsElemTy)) {
@ -1080,7 +1071,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return op.emitError("only ranked tensor type is supported."); return op.emitError("only ranked tensor type is supported.");
} }
@ -1103,11 +1094,11 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
AtenLog2Op op, OpAdaptor adaptor, AtenLog2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return op.emitError("only ranked tensor type is supported."); return op.emitError("only ranked tensor type is supported.");
} }
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>(); auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
@ -1124,12 +1115,12 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
AtenLog10Op op, OpAdaptor adaptor, AtenLog10Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return op.emitError("only ranked tensor type is supported."); return op.emitError("only ranked tensor type is supported.");
} }
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>(); auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
@ -1146,8 +1137,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
AtenErfOp op, OpAdaptor adaptor, AtenErfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().cast<TensorType>(); auto inputType = cast<TensorType>(input.getType());
if (!inputType.getElementType().isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inputType.getElementType())) {
return rewriter.notifyMatchFailure(op, "only float tensor is supported"); return rewriter.notifyMatchFailure(op, "only float tensor is supported");
} }
rewriter.replaceOpWithNewOp<chlo::ErfOp>( rewriter.replaceOpWithNewOp<chlo::ErfOp>(
@ -1161,7 +1152,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor, AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getInput(); Value input = adaptor.getInput();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
Value weight = adaptor.getWeight(); Value weight = adaptor.getWeight();
Value bias = adaptor.getBias(); Value bias = adaptor.getBias();
Value runningMean = adaptor.getRunningMean(); Value runningMean = adaptor.getRunningMean();
@ -1174,10 +1165,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// all of NC, NCL, NCHW, NCDHW's feature index is 1. // all of NC, NCL, NCHW, NCDHW's feature index is 1.
int64_t feature_index = 1; int64_t feature_index = 1;
if (!inputTy.getElementType().template isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(inputTy.getElementType())) {
return op.emitError("only input tensor of float type is supported"); return op.emitError("only input tensor of float type is supported");
} }
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>(); auto inputElemTy = cast<mlir::FloatType>(inputTy.getElementType());
Value channelDim = Value channelDim =
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index); rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
@ -1220,20 +1211,20 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
inputTy.getElementType())); inputTy.getElementType()));
} }
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
auto biasTy = bias.getType().cast<RankedTensorType>(); auto biasTy = cast<RankedTensorType>(bias.getType());
auto runningMeanTy = runningMean.getType().cast<RankedTensorType>(); auto runningMeanTy = cast<RankedTensorType>(runningMean.getType());
auto runningVarTy = runningVar.getType().cast<RankedTensorType>(); auto runningVarTy = cast<RankedTensorType>(runningVar.getType());
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expect weight, bias, running_mean and running_var to be rank 1"); op, "expect weight, bias, running_mean and running_var to be rank 1");
} }
if (!weightTy.getElementType().template isa<mlir::FloatType>() || if (!isa<mlir::FloatType>(weightTy.getElementType()) ||
!biasTy.getElementType().template isa<mlir::FloatType>() || !isa<mlir::FloatType>(biasTy.getElementType()) ||
!runningMeanTy.getElementType().template isa<mlir::FloatType>() || !isa<mlir::FloatType>(runningMeanTy.getElementType()) ||
!runningVarTy.getElementType().template isa<mlir::FloatType>()) { !isa<mlir::FloatType>(runningVarTy.getElementType())) {
return op.emitError("only float weight/bias/runningMean/runningVar tensor " return op.emitError("only float weight/bias/runningMean/runningVar tensor "
"of float type is supported"); "of float type is supported");
} }
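The isa rewrites follow the same scheme, and the free isa<> additionally accepts several candidate types at once, which can fold chains of element-type checks. A small sketch with illustrative names:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

bool isFloatTensor(RankedTensorType ty) {
  // was: ty.getElementType().isa<mlir::FloatType>()
  return isa<FloatType>(ty.getElementType());
}

bool isIntOrFloat(Type elemTy) {
  // variadic form: true if elemTy is any of the listed types
  return isa<IntegerType, FloatType>(elemTy);
}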
@ -1261,8 +1252,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// supported mixed types, like input type is fp16 and weight type is fp32. // supported mixed types, like input type is fp16 and weight type is fp32.
if (inputTy.getElementType() != weightTy.getElementType()) { if (inputTy.getElementType() != weightTy.getElementType()) {
RankedTensorType convertedType = inputTy; RankedTensorType convertedType = inputTy;
if (weightTy.getElementType().cast<FloatType>().getWidth() > if (cast<FloatType>(weightTy.getElementType()).getWidth() >
inputTy.getElementType().cast<FloatType>().getWidth()) { cast<FloatType>(inputTy.getElementType()).getWidth()) {
convertedType = RankedTensorType::get(inputTy.getShape(), convertedType = RankedTensorType::get(inputTy.getShape(),
weightTy.getElementType()); weightTy.getElementType());
} }
@ -1302,8 +1293,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// supported mixed types, like input type is fp16 and weight type is fp32. // supported mixed types, like input type is fp16 and weight type is fp32.
if (inputTy.getElementType() != weightTy.getElementType()) { if (inputTy.getElementType() != weightTy.getElementType()) {
RankedTensorType convertedType = inputTy; RankedTensorType convertedType = inputTy;
if (weightTy.getElementType().cast<FloatType>().getWidth() > if (cast<FloatType>(weightTy.getElementType()).getWidth() >
inputTy.getElementType().cast<FloatType>().getWidth()) { cast<FloatType>(inputTy.getElementType()).getWidth()) {
convertedType = RankedTensorType::get(inputTy.getShape(), convertedType = RankedTensorType::get(inputTy.getShape(),
weightTy.getElementType()); weightTy.getElementType());
} }
@ -1340,7 +1331,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
AtenNativeLayerNormOp op, OpAdaptor adaptor, AtenNativeLayerNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getInput(); Value input = adaptor.getInput();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto inputShape = inputTy.getShape(); auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank(); auto inputRank = inputTy.getRank();
Value weight = adaptor.getWeight(); Value weight = adaptor.getWeight();
@ -1365,12 +1356,12 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
failed(checkNotNone(rewriter, op, bias))) { failed(checkNotNone(rewriter, op, bias))) {
return op->emitError("none weight or bias is unsupported"); return op->emitError("none weight or bias is unsupported");
} }
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
auto biasTy = bias.getType().cast<RankedTensorType>(); auto biasTy = cast<RankedTensorType>(bias.getType());
if (!inputTy.getElementType().isa<mlir::FloatType>() || if (!isa<mlir::FloatType>(inputTy.getElementType()) ||
!biasTy.getElementType().isa<mlir::FloatType>() || !isa<mlir::FloatType>(biasTy.getElementType()) ||
!weightTy.getElementType().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(weightTy.getElementType())) {
return op->emitError("currently only float data type are supported"); return op->emitError("currently only float data type are supported");
} }
int64_t normalizedShapeRank = normalizedShape.size(); int64_t normalizedShapeRank = normalizedShape.size();
@ -1423,7 +1414,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
SmallVector<APFloat> oneConstVec( SmallVector<APFloat> oneConstVec(
numFeatureDimSize, numFeatureDimSize,
APFloat( APFloat(
inputTy.getElementType().cast<mlir::FloatType>().getFloatSemantics(), cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics(),
1)); 1));
auto oneOrZeroConstType = auto oneOrZeroConstType =
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
@ -1443,9 +1434,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Reshape back // Reshape back
auto outputTy = auto outputTy =
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
auto outputMeanOrVarTy = auto outputMeanOrVarTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
auto output = rewriter.create<stablehlo::DynamicReshapeOp>( auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
@ -1482,7 +1473,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
AtenCatOp op, OpAdaptor adaptor, AtenCatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
int64_t dim; int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -1516,7 +1507,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
AtenNumelOp op, OpAdaptor adaptor, AtenNumelOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf(); auto self = adaptor.getSelf();
auto selfTy = self.getType().dyn_cast<RankedTensorType>(); auto selfTy = dyn_cast<RankedTensorType>(self.getType());
size_t rank = selfTy.getRank(); size_t rank = selfTy.getRank();
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
@ -1544,7 +1535,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
AtenClampOp op, OpAdaptor adaptor, AtenClampOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputElemType = inputType.getElementType(); auto inputElemType = inputType.getElementType();
Value minValue = adaptor.getMin(); Value minValue = adaptor.getMin();
Value maxValue = adaptor.getMax(); Value maxValue = adaptor.getMax();
@ -1716,7 +1707,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto outType = auto outType =
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>(); cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
if (!outType) { if (!outType) {
return op.emitError("only tensor type is supported"); return op.emitError("only tensor type is supported");
} }
@ -1764,15 +1755,15 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
AtenPowTensorTensorOp op, OpAdaptor adaptor, AtenPowTensorTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<TensorType>(); auto lhsTy = cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent(); Value rhs = adaptor.getExponent();
auto rhsTy = rhs.getType().cast<TensorType>(); auto rhsTy = cast<TensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError("only Tensor types supported"); return op.emitError("only Tensor types supported");
auto outTy = auto outTy =
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>(); cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
@ -1790,12 +1781,12 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
Value generator = adaptor.getGenerator(); Value generator = adaptor.getGenerator();
Location loc = op.getLoc(); Location loc = op.getLoc();
if (!generator.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
auto elements = self.getType().cast<RankedTensorType>().getShape(); auto elements = cast<RankedTensorType>(self.getType()).getShape();
if (llvm::any_of(elements, if (llvm::any_of(elements,
[](int64_t dim) { return dim == ShapedType::kDynamic; })) [](int64_t dim) { return dim == ShapedType::kDynamic; }))
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
@ -1824,14 +1815,14 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
// The pin_memory should be either `False` or `none`. // The pin_memory should be either `False` or `none`.
bool pinMemory; bool pinMemory;
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) pinMemory))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false"); op, "unimplemented: pin_memory must be either None or false");
// Only `none`, `contiguous` and `preserve` memory_format is supported. // Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat; int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1844,7 +1835,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
"memory_format is supported"); "memory_format is supported");
} }
if (!op.getDevice().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device; std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1853,7 +1844,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
// TODO: Add support for non-strided layout. // TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0. // torch.layout is by default strided i.e. 0.
if (!op.getLayout().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout; int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
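The optional-operand checks against !torch.none in this hunk and the two above it use the same free isa. A hedged sketch, assuming the torch-mlir type headers:

#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// True when an optional torch operand was passed as `None`.
static bool isNoneOperand(mlir::Value v) {
  return mlir::isa<mlir::torch::Torch::NoneType>(v.getType());
}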
@ -1876,9 +1867,9 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = auto resultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType; Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = resultType.getElementType(); resultElementType = resultType.getElementType();
} else { } else {
int64_t dtypeInt; int64_t dtypeInt;
@ -1931,7 +1922,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
AtenFillScalarOp op, OpAdaptor adaptor, AtenFillScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto dtype = outType.getElementType(); auto dtype = outType.getElementType();
Value scalarTensor = Value scalarTensor =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
@ -1951,7 +1942,7 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) { if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
@ -64,7 +64,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
loc, rewriter.getIntegerAttr(intType, 1)); loc, rewriter.getIntegerAttr(intType, 1));
// sliceSizes // sliceSizes
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>(); auto inputRankTy = dyn_cast<RankedTensorType>(input.getType());
auto inputRank = inputRankTy.getRank(); auto inputRank = inputRankTy.getRank();
SmallVector<Value, 4> sliceSizes; SmallVector<Value, 4> sliceSizes;
sliceSizes.reserve(inputRank); sliceSizes.reserve(inputRank);
@ -85,7 +85,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
for (int64_t r = 0; r < axis; ++r) { for (int64_t r = 0; r < axis; ++r) {
offsetDims.push_back(r); offsetDims.push_back(r);
} }
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>(); auto indicesRankTy = dyn_cast<RankedTensorType>(indices.getType());
auto indicesRank = indicesRankTy.getRank(); auto indicesRank = indicesRankTy.getRank();
for (int64_t r = axis + 1; r < inputRank; ++r) { for (int64_t r = axis + 1; r < inputRank; ++r) {
offsetDims.push_back(r + indicesRank - 1); offsetDims.push_back(r + indicesRank - 1);
@ -132,8 +132,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
SmallVector<Value> &strides) { SmallVector<Value> &strides) {
Location loc = op.getLoc(); Location loc = op.getLoc();
auto input = adaptor.getSelf(); auto input = adaptor.getSelf();
RankedTensorType inputType = input.getType().template cast<RankedTensorType>();
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@ -161,7 +160,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
int64_t step; int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
if (!op.getStep().getType().template isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getStep().getType()))
return op->emitError("unimplemented: step is not constant"); return op->emitError("unimplemented: step is not constant");
step = 1; step = 1;
} }
@ -225,7 +224,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
// concat index tensor into to indices tensor for concat // concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) { for (size_t i = 0; i < indexTensors.size(); i++) {
auto indexTensor = indexTensors[i]; auto indexTensor = indexTensors[i];
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>(); auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
if (size == kUnknownSize) if (size == kUnknownSize)
return failure(); return failure();
@ -249,7 +248,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
SmallVector<Value> broadcastedIndices; SmallVector<Value> broadcastedIndices;
Type indexElemTy = Type indexElemTy =
indexTensors[0].getType().cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(indexTensors[0].getType()).getElementType();
RankedTensorType bcastIndexType = RankedTensorType bcastIndexType =
RankedTensorType::get(indicesShape, indexElemTy); RankedTensorType::get(indicesShape, indexElemTy);
for (auto indexTensor : indexTensors) { for (auto indexTensor : indexTensors) {
@ -290,7 +289,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
AtenEmbeddingOp op, OpAdaptor adaptor, AtenEmbeddingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto weight = adaptor.getWeight(); auto weight = adaptor.getWeight();
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
if (!weightTy) if (!weightTy)
return op.emitError("only ranked tensor types are supported"); return op.emitError("only ranked tensor types are supported");
@ -332,17 +331,17 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
Value indices = adaptor.getIndices(); Value indices = adaptor.getIndices();
Value offsets = adaptor.getOffsets(); Value offsets = adaptor.getOffsets();
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2) if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "weight must be rank 2 tensor with static shapes"); op, "weight must be rank 2 tensor with static shapes");
auto indicesTy = indices.getType().cast<RankedTensorType>(); auto indicesTy = cast<RankedTensorType>(indices.getType());
if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1) if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "indices must be a vector with static shapes"); op, "indices must be a vector with static shapes");
auto offsetsTy = offsets.getType().cast<RankedTensorType>(); auto offsetsTy = cast<RankedTensorType>(offsets.getType());
if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() && if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() &&
offsetsTy.getShape()[0] == 1) offsetsTy.getShape()[0] == 1)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -485,7 +484,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor, AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf(); auto self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only ranked tensor types are supported"); return op.emitError("only ranked tensor types are supported");
int64_t dim; int64_t dim;
@ -514,8 +513,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
Location loc = op->getLoc(); Location loc = op->getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Value index = adaptor.getIndex(); Value index = adaptor.getIndex();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto indexType = index.getType().cast<RankedTensorType>(); auto indexType = cast<RankedTensorType>(index.getType());
auto indexElemType = indexType.getElementType(); auto indexElemType = indexType.getElementType();
if (indexType.getRank() != inputType.getRank()) { if (indexType.getRank() != inputType.getRank()) {
@ -623,7 +622,7 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
} }
Value src = adaptor.getSrc(); Value src = adaptor.getSrc();
auto srcType = src.getType().cast<RankedTensorType>(); auto srcType = cast<RankedTensorType>(src.getType());
int64_t srcRank = srcType.getRank(); int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize); SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
auto abstractSrcType = RankedTensorType::get( auto abstractSrcType = RankedTensorType::get(
@ -651,9 +650,9 @@ public:
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Value index = adaptor.getIndex(); Value index = adaptor.getIndex();
Value src = adaptor.getSrc(); Value src = adaptor.getSrc();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto indexType = index.getType().cast<RankedTensorType>(); auto indexType = cast<RankedTensorType>(index.getType());
auto srcType = src.getType().cast<RankedTensorType>(); auto srcType = cast<RankedTensorType>(src.getType());
auto indexElemType = indexType.getElementType(); auto indexElemType = indexType.getElementType();
if (indexType.getRank() != inputType.getRank() || if (indexType.getRank() != inputType.getRank() ||
@ -789,9 +788,9 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc(); Location loc = op->getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTensorType = input.getType().cast<RankedTensorType>(); auto inputTensorType = cast<RankedTensorType>(input.getType());
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto outShape = outType.getShape(); auto outShape = outType.getShape();
Value indexList = op.getIndices(); Value indexList = op.getIndices();
SmallVector<Value> indicesTorchType; SmallVector<Value> indicesTorchType;
@ -857,10 +856,10 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Value values = adaptor.getValues(); Value values = adaptor.getValues();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank(); int64_t inputRank = inputType.getRank();
auto valuesType = values.getType().cast<RankedTensorType>(); auto valuesType = cast<RankedTensorType>(values.getType());
auto valuesShape = valuesType.getShape(); auto valuesShape = valuesType.getShape();
bool accumulate; bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) { if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
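
All of the hunks above are instances of the same mechanical rewrite: the deprecated member-function casts on mlir::Type give way to the equivalent free functions. A minimal before/after sketch, illustrative only and not code from this commit, assuming the MLIR headers are on the include path:

#include "mlir/IR/BuiltinTypes.h" // RankedTensorType
#include "mlir/IR/Value.h"        // mlir::Value
#include "mlir/Support/LLVM.h"    // re-exports llvm::cast and friends

using namespace mlir;

int64_t rankOf(Value v) {
  // Deprecated member form, as the left half of each changed line reads:
  //   auto ty = v.getType().cast<RankedTensorType>();
  // Free-function form, as the right half reads:
  auto ty = cast<RankedTensorType>(v.getType());
  return ty.getRank();
}

Only the call syntax changes; operands, result types, and behavior are identical, which is why each hunk is a near one-line substitution.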

@ -32,7 +32,7 @@ namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes, ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes,
ArrayRef<int64_t> broadcastDims) { ArrayRef<int64_t> broadcastDims) {
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>(); auto tensorTy = dyn_cast<RankedTensorType>(tensor.getType());
auto loc = op->getLoc(); auto loc = op->getLoc();
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes); Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
@ -48,7 +48,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
ArrayRef<int64_t> inpTransDims) { ArrayRef<int64_t> inpTransDims) {
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto rank = inputTy.getRank(); auto rank = inputTy.getRank();
auto transDims = hlo::toPositiveDims(inpTransDims, rank); auto transDims = hlo::toPositiveDims(inpTransDims, rank);
auto inpShape = inputTy.getShape(); auto inpShape = inputTy.getShape();
@ -70,8 +70,8 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
int64_t lhsResultDim, int64_t rhsResultDim, int64_t lhsResultDim, int64_t rhsResultDim,
int64_t lhsContractingDim, int64_t lhsContractingDim,
int64_t rhsContractingDim) { int64_t rhsContractingDim) {
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>(); auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>(); auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
auto oldLhsShape = lhsTy.getShape(); auto oldLhsShape = lhsTy.getShape();
auto oldRhsShape = rhsTy.getShape(); auto oldRhsShape = rhsTy.getShape();
@ -129,8 +129,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
Value lhs = inpLhs; Value lhs = inpLhs;
Value rhs = inpRhs; Value rhs = inpRhs;
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>(); auto lhsRankTy = dyn_cast<RankedTensorType>(inpLhs.getType());
auto rhsRankTy = inpRhs.getType().dyn_cast<RankedTensorType>(); auto rhsRankTy = dyn_cast<RankedTensorType>(inpRhs.getType());
auto lhsRank = lhsRankTy.getRank(); auto lhsRank = lhsRankTy.getRank();
auto rhsRank = rhsRankTy.getRank(); auto rhsRank = rhsRankTy.getRank();
@ -177,8 +177,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
return; return;
} }
lhsShape = lhs.getType().cast<RankedTensorType>().getShape(); lhsShape = cast<RankedTensorType>(lhs.getType()).getShape();
rhsShape = rhs.getType().cast<RankedTensorType>().getShape(); rhsShape = cast<RankedTensorType>(rhs.getType()).getShape();
// check shape compatibility, check if we should broadcast // check shape compatibility, check if we should broadcast
// first, we should get a new batch shape. Check from (0, nBatchDims) // first, we should get a new batch shape. Check from (0, nBatchDims)
@ -266,8 +266,8 @@ public:
LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor, LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &lhs, ConversionPatternRewriter &rewriter, Value &lhs,
Value &rhs, Value &output) const { Value &rhs, Value &output) const {
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = cast<RankedTensorType>(rhs.getType());
auto lhsRank = lhsTy.getRank(); auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank(); auto rhsRank = rhsTy.getRank();
@ -370,10 +370,10 @@ public:
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Value &lhs, Value &rhs) const override { Value &lhs, Value &rhs) const override {
lhs = adaptor.getSelf(); lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
rhs = adaptor.getOther(); rhs = adaptor.getOther();
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = cast<RankedTensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
@ -393,10 +393,10 @@ public:
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Value &lhs, Value &rhs) const override { Value &lhs, Value &rhs) const override {
lhs = adaptor.getSelf(); lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
rhs = adaptor.getMat2(); rhs = adaptor.getMat2();
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = cast<RankedTensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
@ -429,10 +429,10 @@ public:
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Value &lhs, Value &rhs) const override { Value &lhs, Value &rhs) const override {
lhs = adaptor.getInput(); lhs = adaptor.getInput();
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
rhs = adaptor.getWeight(); rhs = adaptor.getWeight();
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = cast<RankedTensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
@ -464,16 +464,15 @@ public:
auto biasTy = bias.getType(); auto biasTy = bias.getType();
// StableHLO does not mandate that elementwise op tensors need to be ranked. // StableHLO does not mandate that elementwise op tensors need to be ranked.
if (!biasTy.template isa<Torch::NoneType>() && !biasTy.template isa<RankedTensorType>()) if (!isa<Torch::NoneType>(biasTy) && !isa<RankedTensorType>(biasTy))
return op.emitError("only ranked tensor types are supported in StableHLO " return op.emitError("only ranked tensor types are supported in StableHLO "
"matmul for bias tensor"); "matmul for bias tensor");
// weight.T // weight.T
rhs = getPermutedTensor(rewriter, op, rhs, {1, 0}); rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});
auto lhsTy = lhs.getType().cast<RankedTensorType>(); auto lhsTy = cast<RankedTensorType>(lhs.getType());
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = cast<RankedTensorType>(rhs.getType());
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
rhsTy.getRank() - lhsTy.getRank()); rhsTy.getRank() - lhsTy.getRank());
@ -503,7 +502,7 @@ public:
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
Value matmulPlusBias = matmulOutput; Value matmulPlusBias = matmulOutput;
if (!biasTy.template isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(biasTy)) {
// Bias addition broadcasts to the matmul output shape. // Bias addition broadcasts to the matmul output shape.
matmulPlusBias = rewriter matmulPlusBias = rewriter
.create<chlo::BroadcastAddOp>( .create<chlo::BroadcastAddOp>(
@ -525,7 +524,7 @@ public:
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
Value weight, int64_t groups) const { Value weight, int64_t groups) const {
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
auto weightElemTy = weightTy.getElementType(); auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank(); auto rank = weightTy.getRank();
const auto &options = getOptions(); const auto &options = getOptions();
@ -588,8 +587,8 @@ public:
ArrayRef<int64_t> dilation, ArrayRef<int64_t> dilation,
ArrayRef<int64_t> outputPadding, ArrayRef<int64_t> outputPadding,
int64_t groups) const { int64_t groups) const {
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
auto nDims = inputTy.getRank(); auto nDims = inputTy.getRank();
@ -727,11 +726,11 @@ public:
Value weight = adaptor.getWeight(); Value weight = adaptor.getWeight();
// The input shape is [N, C, H, W] // The input shape is [N, C, H, W]
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
// The weight shape is [OC, (IC//G), KH, KW] // The weight shape is [OC, (IC//G), KH, KW]
// If transposed is set to true, // If transposed is set to true,
// the weight shape changes to [IC, (OC//G), KH, KW] // the weight shape changes to [IC, (OC//G), KH, KW]
auto weightTy = weight.getType().template cast<RankedTensorType>(); auto weightTy = cast<RankedTensorType>(weight.getType());
auto outTy = getTypeConverter() auto outTy = getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<RankedTensorType>(); .template cast<RankedTensorType>();
@ -819,11 +818,11 @@ public:
} }
// Handle bias // Handle bias
if (!bias.getType().cast<RankedTensorType>()) { if (!cast<RankedTensorType>(bias.getType())) {
return op.emitError("bias provided but not a ranked tensor"); return op.emitError("bias provided but not a ranked tensor");
} }
auto biasTy = bias.getType().cast<RankedTensorType>(); auto biasTy = cast<RankedTensorType>(bias.getType());
if (!biasTy.getElementType().isIntOrFloat()) { if (!biasTy.getElementType().isIntOrFloat()) {
return op.emitError("only floating-point or integer datatype " return op.emitError("only floating-point or integer datatype "
"legalization for bias supported"); "legalization for bias supported");

@ -81,12 +81,12 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor, AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank(); auto inputRank = inputTy.getRank();
auto outTy = auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (inputRank <= 2) { if (inputRank <= 2) {
return op.emitError( return op.emitError(
@ -176,14 +176,14 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
auto inputShape = inputTy.getShape(); auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank(); auto inputRank = inputTy.getRank();
auto outValTy = auto outValTy =
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
auto outIdxTy = auto outIdxTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
if (inputRank <= 2) { if (inputRank <= 2) {
return op.emitError( return op.emitError(
@ -366,7 +366,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
RankedTensorType inputTy = input.getType().cast<RankedTensorType>(); RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
Type inputElemTy = inputTy.getElementType(); Type inputElemTy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank(); int64_t inputRank = inputTy.getRank();
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter() RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
@ -539,11 +539,11 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
AtenCumsumOp op, OpAdaptor adaptor, AtenCumsumOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
auto outTy = auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
inputTy = input.getType().cast<RankedTensorType>(); inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank(); auto inputRank = inputTy.getRank();
auto inputShape = inputTy.getShape(); auto inputShape = inputTy.getShape();
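
The pooling hunks apply the same substitution to results of the conversion framework's type converter, which hands back a plain mlir::Type that then needs narrowing. A hedged sketch of that step (the helper name and signature are invented; cast<> still asserts if the converted type is not a ranked tensor):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h" // TypeConverter

using namespace mlir;

RankedTensorType convertedRankedType(const TypeConverter &converter,
                                     Type torchType) {
  // Before: converter.convertType(torchType).cast<RankedTensorType>();
  return cast<RankedTensorType>(converter.convertType(torchType));
}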

@ -126,7 +126,7 @@ static std::optional<ValueRange>
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
ArrayRef<Value> inputShapeVec, int64_t dim, ArrayRef<Value> inputShapeVec, int64_t dim,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return std::nullopt; return std::nullopt;
} }
@ -249,7 +249,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
AtenArgmaxOp op, OpAdaptor adaptor, AtenArgmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO"); op, "only Tensor types supported in StableHLO");
@ -321,7 +321,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
AtenMaxDimOp op, OpAdaptor adaptor, AtenMaxDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO"); op, "only Tensor types supported in StableHLO");
@ -410,7 +410,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
AtenSumOp op, OpAdaptor adaptor, AtenSumOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter() auto outTy = getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
@ -423,7 +423,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy); rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = dyn_cast<RankedTensorType>(input.getType());
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -626,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
AtenProdOp op, OpAdaptor adaptor, AtenProdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter() auto outTy = getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
@ -639,7 +639,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy); rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = dyn_cast<RankedTensorType>(input.getType());
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -699,7 +699,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
AtenMaxOp op, OpAdaptor adaptor, AtenMaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO"); op, "only Tensor types supported in StableHLO");
@ -762,7 +762,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
AtenMinOp op, OpAdaptor adaptor, AtenMinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO"); op, "only Tensor types supported in StableHLO");
@ -825,7 +825,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
AtenSumDimIntListOp op, OpAdaptor adaptor, AtenSumDimIntListOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter() auto outTy = getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
@ -838,7 +838,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy); rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = dyn_cast<RankedTensorType>(input.getType());
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -958,7 +958,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
const TorchToStablehloOptions &options = getOptions(); const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>(); auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType) { if (!inputType) {
return op.emitError( return op.emitError(
"only ranked tensor input supported in AtenFrobeniusNormDimOp"); "only ranked tensor input supported in AtenFrobeniusNormDimOp");
@ -1070,7 +1070,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
const TorchToStablehloOptions &options = getOptions(); const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>(); auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType) { if (!inputType) {
return op.emitError( return op.emitError(
"only ranked tensor input supported in AtenLinalgVectorNormOp"); "only ranked tensor input supported in AtenLinalgVectorNormOp");
@ -1078,7 +1078,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
int64_t inputRank = inputType.getRank(); int64_t inputRank = inputType.getRank();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto outElemType = outType.getElementType(); auto outElemType = outType.getElementType();
if (!isa<mlir::FloatType>(outElemType)) { if (!isa<mlir::FloatType>(outElemType)) {
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");

@ -144,7 +144,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Value promoteType(PatternRewriter &rewriter, Location loc, Value input, Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
TensorType outType) { TensorType outType) {
TensorType in_type = input.getType().cast<TensorType>(); TensorType in_type = cast<TensorType>(input.getType());
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType = TensorType promotedType =
@ -162,7 +162,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
// dimension, the dimension sizes must either be equal, one of them is 1, or // dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist. // one of them does not exist.
Operation *op = input.getDefiningOp(); Operation *op = input.getDefiningOp();
TensorType in_type = input.getType().dyn_cast<TensorType>(); TensorType in_type = dyn_cast<TensorType>(input.getType());
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type = TensorType promoted_type =
@ -217,7 +217,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value, Operation *op, Value value,
ArrayRef<int64_t> inpDims, ArrayRef<int64_t> inpDims,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>(); auto valueTy = dyn_cast<RankedTensorType>(value.getType());
if (!valueTy) { if (!valueTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "getDimSizesOfTensor(): the input is not a ranked tensor"); op, "getDimSizesOfTensor(): the input is not a ranked tensor");
@ -240,7 +240,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter, FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value, Operation *op, Value value,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>(); auto valueTy = dyn_cast<RankedTensorType>(value.getType());
if (!valueTy) { if (!valueTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "getDimSizesOfTensor(): the input is not a ranked tensor"); op, "getDimSizesOfTensor(): the input is not a ranked tensor");
@ -279,7 +279,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
op, "unsqueeze dimensions must be specified in order"); op, "unsqueeze dimensions must be specified in order");
auto loc = op->getLoc(); auto loc = op->getLoc();
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>(); auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape(); auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits); Type intType = rewriter.getIntegerType(dimSizeIndexBits);
auto one = rewriter.create<arith::ConstantOp>( auto one = rewriter.create<arith::ConstantOp>(
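
Note that none of the rewritten files gain new includes: MLIR re-exports the LLVM casting machinery into its own namespace, so the unqualified names already resolve. Abridged from mlir/Support/LLVM.h (consult the actual header for the authoritative list):

// MLIR pulls the LLVM-style RTTI entry points into namespace mlir, so
// conversion code can call them unqualified once that namespace is open.
namespace mlir {
using llvm::cast;
using llvm::cast_or_null;
using llvm::dyn_cast;
using llvm::dyn_cast_or_null;
using llvm::isa;
} // namespace mlir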

@ -72,7 +72,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
SmallVector<Value, 4> endIndices; SmallVector<Value, 4> endIndices;
SmallVector<Value, 4> strides; SmallVector<Value, 4> strides;
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
size_t rank = inputTy.getRank(); size_t rank = inputTy.getRank();
startIndices.reserve(rank); startIndices.reserve(rank);
endIndices.reserve(rank); endIndices.reserve(rank);
@ -116,7 +116,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
std::optional<Value> stepOpt, int64_t dim, std::optional<Value> stepOpt, int64_t dim,
size_t dimSizeIndexBits) { size_t dimSizeIndexBits) {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto rank = inputTy.getRank(); auto rank = inputTy.getRank();
dim = (dim + rank) % rank; dim = (dim + rank) % rank;
@ -168,8 +168,7 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto rankType = adaptor.getSelf().getType().template dyn_cast<RankedTensorType>(); auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
if (!rankType) if (!rankType)
return op.emitError("Only ranked tensor types are currently supported"); return op.emitError("Only ranked tensor types are currently supported");
@ -233,11 +232,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
AtenSliceTensorOp op, OpAdaptor adaptor, AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf(); auto self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only ranked tensor types are supported"); return op.emitError("only ranked tensor types are supported");
auto outTy = auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
int64_t dim; int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -275,7 +274,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
AtenSqueezeOp op, OpAdaptor adaptor, AtenSqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf(); auto self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only ranked tensor types are supported"); return op.emitError("only ranked tensor types are supported");
@ -318,7 +317,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
AtenSqueezeDimOp op, OpAdaptor adaptor, AtenSqueezeDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf(); auto self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only ranked tensor types are supported"); return op.emitError("only ranked tensor types are supported");
@ -369,7 +368,7 @@ template <>
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
AtenUnsqueezeOp op, OpAdaptor adaptor, AtenUnsqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>(); auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
if (!selfType) { if (!selfType) {
return op.emitError("only tensor types are currently supported"); return op.emitError("only tensor types are currently supported");
} }
@ -378,7 +377,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant"); return op->emitError("dim must be a Scalar constant");
int64_t inputRank = int64_t inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
dim = toPositiveDim(dim, inputRank + 1); dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1)) if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@ -397,7 +396,7 @@ template <>
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite( LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
PrimsCollapseOp op, OpAdaptor adaptor, PrimsCollapseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>(); auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
if (!selfType) { if (!selfType) {
return op.emitError("only tensor types are currently supported"); return op.emitError("only tensor types are currently supported");
} }
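
One detail worth flagging in the view-like hunks: the mechanical rewrite preserves guards of the form auto selfTy = cast<RankedTensorType>(...); if (!selfTy), which cannot fire because cast<> asserts on a mismatched type. Where such a guard is meant to be reachable, dyn_cast<> is the idiom, as in this hedged sketch (names invented for illustration):

#include <optional>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// A reachable guard pairs with dyn_cast: it yields a null type on
// mismatch, where cast would have asserted first.
std::optional<int64_t> rankIfRanked(Value v) {
  auto selfTy = dyn_cast<RankedTensorType>(v.getType());
  if (!selfTy)
    return std::nullopt; // this branch can actually execute
  return selfTy.getRank();
}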

@ -89,8 +89,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
Value indices, Value src, Value indices, Value src,
int64_t dim) { int64_t dim) {
// Get information on types for inputs // Get information on types for inputs
RankedTensorType indexType = indices.getType().cast<RankedTensorType>(); RankedTensorType indexType = cast<RankedTensorType>(indices.getType());
RankedTensorType srcSelf = src.getType().cast<RankedTensorType>(); RankedTensorType srcSelf = cast<RankedTensorType>(src.getType());
// Store location for insertions // Store location for insertions
Location loc = src.getLoc(); Location loc = src.getLoc();
@ -219,7 +219,7 @@ static Value createTMTensorScatterOp(
llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices, llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) { function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
auto originalTensorType = original.getType().cast<RankedTensorType>(); auto originalTensorType = cast<RankedTensorType>(original.getType());
Type originalElementType = originalTensorType.getElementType(); Type originalElementType = originalTensorType.getElementType();
auto scatterOp = b.create<TMTensor::ScatterOp>( auto scatterOp = b.create<TMTensor::ScatterOp>(
loc, originalTensorType, ValueRange{updates, indices}, loc, originalTensorType, ValueRange{updates, indices},
@ -241,8 +241,8 @@ static Value createTMTensorScanOp(
OpBuilder &b, Location loc, Value input, Value output, Value accumulator, OpBuilder &b, Location loc, Value input, Value output, Value accumulator,
int64_t dim, bool inclusive, int64_t dim, bool inclusive,
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) { function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto accType = accumulator.getType().cast<RankedTensorType>(); auto accType = cast<RankedTensorType>(accumulator.getType());
Type elementType = inputType.getElementType(); Type elementType = inputType.getElementType();
auto scanOp = b.create<TMTensor::ScanOp>( auto scanOp = b.create<TMTensor::ScanOp>(
loc, TypeRange{inputType, accType}, input, loc, TypeRange{inputType, accType}, input,
@ -287,7 +287,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
// Step 3. Create comparison op which will be used as the sorting predicate. // Step 3. Create comparison op which will be used as the sorting predicate.
Value compareOp; Value compareOp;
if (auto intType = elementTypes[0].dyn_cast<mlir::IntegerType>()) { if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
// Case for using arith::CmpIOp. // Case for using arith::CmpIOp.
arith::CmpIPredicate ge = arith::CmpIPredicate::sge; arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
arith::CmpIPredicate le = arith::CmpIPredicate::sle; arith::CmpIPredicate le = arith::CmpIPredicate::sle;
@ -329,9 +329,9 @@ public:
Value index = adaptor.getIndex(); Value index = adaptor.getIndex();
Value src = adaptor.getSrc(); Value src = adaptor.getSrc();
RankedTensorType selfType = self.getType().cast<RankedTensorType>(); RankedTensorType selfType = cast<RankedTensorType>(self.getType());
RankedTensorType indexType = index.getType().cast<RankedTensorType>(); RankedTensorType indexType = cast<RankedTensorType>(index.getType());
RankedTensorType srcType = src.getType().cast<RankedTensorType>(); RankedTensorType srcType = cast<RankedTensorType>(src.getType());
if (selfType.getRank() != indexType.getRank() || if (selfType.getRank() != indexType.getRank() ||
indexType.getRank() != srcType.getRank()) indexType.getRank() != srcType.getRank())
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -385,7 +385,7 @@ public:
// TODO: Add a check to verify that the input tensor elements are all // TODO: Add a check to verify that the input tensor elements are all
// non-negative. // non-negative.
// Check whether the input is a 1-d tensor of integer type or not. // Check whether the input is a 1-d tensor of integer type or not.
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
if (inputType.getRank() != 1 || if (inputType.getRank() != 1 ||
!inputType.getElementType().isa<mlir::IntegerType>()) !inputType.getElementType().isa<mlir::IntegerType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -394,7 +394,7 @@ public:
// Check whether the input tensor element type is i64 or not. // Check whether the input tensor element type is i64 or not.
IntegerType inputIntegerType = IntegerType inputIntegerType =
inputType.getElementType().cast<IntegerType>(); cast<IntegerType>(inputType.getElementType());
if (inputIntegerType.getWidth() != 64) if (inputIntegerType.getWidth() != 64)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, op,
@ -409,7 +409,7 @@ public:
SmallVector<int64_t> maxTensorSizes; SmallVector<int64_t> maxTensorSizes;
ValueTensorType maxTensorType = ValueTensorType::get( ValueTensorType maxTensorType = ValueTensorType::get(
context, llvm::ArrayRef(maxTensorSizes), context, llvm::ArrayRef(maxTensorSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype()); cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
Value maxTensor = Value maxTensor =
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput); rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
maxTensor = typeConverter->materializeTargetConversion( maxTensor = typeConverter->materializeTargetConversion(
@ -432,7 +432,7 @@ public:
makeShapeTorchCompatible(inputType.getShape())[0], 1}; makeShapeTorchCompatible(inputType.getShape())[0], 1};
ValueTensorType expandInputType = ValueTensorType::get( ValueTensorType expandInputType = ValueTensorType::get(
context, llvm::ArrayRef(expandedInputSizes), context, llvm::ArrayRef(expandedInputSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype()); cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>( Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1)); loc, rewriter.getI64IntegerAttr(1));
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>( Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
@ -571,7 +571,7 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
} }
BaseTensorType unsqueezedTensorType = BaseTensorType unsqueezedTensorType =
indices[0].getType().cast<BaseTensorType>(); cast<BaseTensorType>(indices[0].getType());
Value indicesTorchList = b.create<PrimListConstructOp>( Value indicesTorchList = b.create<PrimListConstructOp>(
loc, Torch::ListType::get(unsqueezedTensorType), indices); loc, Torch::ListType::get(unsqueezedTensorType), indices);
llvm::SmallVector<int64_t, 2> concatShape{ llvm::SmallVector<int64_t, 2> concatShape{
@ -691,7 +691,7 @@ public:
auto inputType = cast<ValueTensorType>(input.getType()); auto inputType = cast<ValueTensorType>(input.getType());
auto valuesType = cast<ValueTensorType>(values.getType()); auto valuesType = cast<ValueTensorType>(values.getType());
int64_t inputRank = inputType.getSizes().size(); int64_t inputRank = inputType.getSizes().size();
auto valuesTensorType = op.getValues().getType().cast<BaseTensorType>(); auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
auto resultType = typeConverter->convertType(op->getResult(0).getType()) auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
@ -902,9 +902,9 @@ public:
Value gradOutput = adaptor.getGradOutput(); Value gradOutput = adaptor.getGradOutput();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
RankedTensorType gradOutputType = RankedTensorType gradOutputType =
gradOutput.getType().cast<RankedTensorType>(); cast<RankedTensorType>(gradOutput.getType());
Type gradOutputElemType = gradOutputType.getElementType(); Type gradOutputElemType = gradOutputType.getElementType();
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
Type inputElemType = inputType.getElementType(); Type inputElemType = inputType.getElementType();
int64_t tensorOperandRank = inputType.getRank(); int64_t tensorOperandRank = inputType.getRank();
@ -914,7 +914,7 @@ public:
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
indices = typeConverter->materializeTargetConversion( indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices); rewriter, loc, typeConverter->convertType(indices.getType()), indices);
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>(); RankedTensorType indicesType = cast<RankedTensorType>(indices.getType());
Type indicesElemType = indicesType.getElementType(); Type indicesElemType = indicesType.getElementType();
// The element type of the `input` and `grad_output` should be same. // The element type of the `input` and `grad_output` should be same.
@ -1100,11 +1100,11 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
RankedTensorType selfType = RankedTensorType selfType =
adaptor.getSelf().getType().cast<RankedTensorType>(); cast<RankedTensorType>(adaptor.getSelf().getType());
RankedTensorType indexType = RankedTensorType indexType =
adaptor.getIndex().getType().cast<RankedTensorType>(); cast<RankedTensorType>(adaptor.getIndex().getType());
RankedTensorType srcType = RankedTensorType srcType =
adaptor.getSrc().getType().cast<RankedTensorType>(); cast<RankedTensorType>(adaptor.getSrc().getType());
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
@ -1324,7 +1324,7 @@ public:
// Step 1. Fetch Input to sort. // Step 1. Fetch Input to sort.
Value inputTensor = adaptor.getSelf(); Value inputTensor = adaptor.getSelf();
auto inputType = inputTensor.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(inputTensor.getType());
unsigned inputRank = inputType.getRank(); unsigned inputRank = inputType.getRank();
// Step 2. Fetch dimension to perform sort in. // Step 2. Fetch dimension to perform sort in.
@ -1414,7 +1414,7 @@ public:
.cast<RankedTensorType>(); .cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
Type inputElementType = Type inputElementType =
input.getType().cast<RankedTensorType>().getElementType(); cast<RankedTensorType>(input.getType()).getElementType();
// Converting the input element type to the result's element type. // Converting the input element type to the result's element type.
// The only possible mismatch would be when the input element type is an // The only possible mismatch would be when the input element type is an
@ -1486,7 +1486,7 @@ public:
Value isCausal = op.getIsCausal(); Value isCausal = op.getIsCausal();
Value scale = op.getScale(); Value scale = op.getScale();
Type elementType = Type elementType =
adaptor.getQuery().getType().cast<ShapedType>().getElementType(); cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
// Verify inputs (only support defaults) // Verify inputs (only support defaults)
if (!mask.getType().isa<Torch::NoneType>()) if (!mask.getType().isa<Torch::NoneType>())
@ -1557,10 +1557,9 @@ public:
key = collapseBatch(key); key = collapseBatch(key);
value = collapseBatch(value); value = collapseBatch(value);
SmallVector<int64_t> outSizes(query.getType().cast<ShapedType>().getShape()); SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
SmallVector<int64_t> valueSizes(value.getType().cast<ShapedType>().getShape()); SmallVector<int64_t> valueSizes(cast<ShapedType>(value.getType()).getShape());
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
SmallVector<Value> outSizesDynamic( SmallVector<Value> outSizesDynamic(
getTensorSizes(rewriter, op.getLoc(), query)); getTensorSizes(rewriter, op.getLoc(), query));
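
The TMTensor sort lowering above dispatches on the element type with dyn_cast in an if-initializer, a shape the free functions keep tidy. A self-contained sketch of that dispatch (the classification is illustrative, not the exact comparator built above):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Branch on the element type the way the sort comparator above does.
const char *describeElementType(Type elemTy) {
  if (auto intTy = dyn_cast<mlir::IntegerType>(elemTy))
    return intTy.isSigned() ? "signed int" : "int";
  if (isa<mlir::FloatType>(elemTy))
    return "float";
  return "other";
}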

@ -79,9 +79,9 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto operand = adaptor.getOperands()[0]; auto operand = adaptor.getOperands()[0];
auto operandTy = operand.getType().cast<RankedTensorType>(); auto operandTy = cast<RankedTensorType>(operand.getType());
auto resultTy = auto resultTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
int64_t rank = operandTy.getRank(); int64_t rank = operandTy.getRank();
if (rank == 0) { if (rank == 0) {

@ -43,7 +43,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy) if (!selfTy)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -93,9 +93,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().cast<TensorType>(); auto lhsTy = cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsTy = rhs.getType().cast<TensorType>(); auto rhsTy = cast<TensorType>(rhs.getType());
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -235,15 +235,15 @@ public:
// alpha : scalar: i32/i64/f32 // alpha : scalar: i32/i64/f32
// output: tensor: tensor<i32/i64/f32> // output: tensor: tensor<i32/i64/f32>
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>(); auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsType = rhs.getType().dyn_cast<TensorType>(); auto rhsType = dyn_cast<TensorType>(rhs.getType());
if (!lhsType) if (!lhsType)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) { if (auto lhsElemTy = dyn_cast<IntegerType>(lhsType.getElementType())) {
if (lhsElemTy.getWidth() > 64) if (lhsElemTy.getWidth() > 64)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Integers with widths greater than 64 are not supported"); op, "Integers with widths greater than 64 are not supported");
@ -284,7 +284,7 @@ public:
op->getLoc(), op->getLoc(),
RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs); RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
// reinitialize right value type to tensor<i32/f32> // reinitialize right value type to tensor<i32/f32>
rhsType = rhs.getType().dyn_cast<TensorType>(); rhsType = dyn_cast<TensorType>(rhs.getType());
} }
auto rhsTensor = rhsType ? rhs : rhsAsTensor; auto rhsTensor = rhsType ? rhs : rhsAsTensor;
@ -337,9 +337,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().dyn_cast<TensorType>(); auto lhsTy = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsTy = rhs.getType().dyn_cast<TensorType>(); auto rhsTy = dyn_cast<TensorType>(rhs.getType());
if (!lhsTy) if (!lhsTy)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -409,7 +409,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>(); auto lhsType = dyn_cast<TensorType>(lhs.getType());
if (!lhsType) if (!lhsType)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -430,7 +430,7 @@ public:
} else { } else {
Value rhsAsTensor; Value rhsAsTensor;
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsType = rhs.getType().dyn_cast<TensorType>(); auto rhsType = dyn_cast<TensorType>(rhs.getType());
if (!rhsType) { if (!rhsType) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
rhsAsTensor, outElemTy, {}))) { rhsAsTensor, outElemTy, {}))) {
@ -469,9 +469,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
auto lhsTy = lhs.getType().dyn_cast<TensorType>(); auto lhsTy = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
auto rhsTy = rhs.getType().dyn_cast<TensorType>(); auto rhsTy = dyn_cast<TensorType>(rhs.getType());
if (!lhsTy) if (!lhsTy)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -497,7 +497,7 @@ public:
// auto result; // auto result;
Value result; Value result;
if (outType.getElementType().template isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(outType.getElementType())) {
// The input to the reciprocal is an integer sometimes, and we may need to // The input to the reciprocal is an integer sometimes, and we may need to
// promote it to a floating point. Per TOSA specification, the input types // promote it to a floating point. Per TOSA specification, the input types
// can only be floating point for tosa::ReciprocalOp. // can only be floating point for tosa::ReciprocalOp.
@ -538,7 +538,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor, AtenTanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<tosa::TanhOp>( rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
@ -555,7 +555,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
AtenSigmoidOp op, OpAdaptor adaptor, AtenSigmoidOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>( rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
@ -572,7 +572,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
AtenReluOp op, OpAdaptor adaptor, AtenReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
// Maps to tosa.clamp which has both int and fp limits. // Maps to tosa.clamp which has both int and fp limits.
int64_t clampMin = 0; int64_t clampMin = 0;
@ -602,7 +602,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy.getElementType().isa<mlir::FloatType>()) { if (!selfTy.getElementType().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported"); op, "Only floating-point datatype legalization currently supported");
@ -660,7 +660,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy) if (!selfTy)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -713,7 +713,7 @@ class ConvertAtenMultipleDimsReductionOp
"non-const dim parameter unsupported"); "non-const dim parameter unsupported");
int64_t N = reduceDims.size(); int64_t N = reduceDims.size();
int64_t inputRank = int64_t inputRank =
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
for (unsigned i = 0; i < N; i++) { for (unsigned i = 0; i < N; i++) {
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank); reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
if (!isValidDim(reduceDims[i], inputRank)) if (!isValidDim(reduceDims[i], inputRank))
@ -751,7 +751,7 @@ class ConvertAtenOneDimReductionOp
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported"); "non-const dim parameter unsupported");
int64_t inputRank = int64_t inputRank =
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
reduceDim = toPositiveDim(reduceDim, inputRank); reduceDim = toPositiveDim(reduceDim, inputRank);
if (!isValidDim(reduceDim, inputRank)) if (!isValidDim(reduceDim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@@ -782,7 +782,7 @@ public:
                                   ElementsAttr &reduceDimsAttr,
                                   bool &keepDims) const override {
     auto self = adaptor.getSelf();
-    auto selfTy = self.getType().template cast<RankedTensorType>();
+    auto selfTy = cast<RankedTensorType>(self.getType());

     // Select all dims to reduce
     SmallVector<int64_t, 4> reduceDims;
@@ -804,7 +804,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   Value self = adaptor.getSelf();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());

   if (!selfTy)
     return rewriter.notifyMatchFailure(
@@ -835,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
   // Create a single instance of tosa.argmax.
   // Multiple dims require chained construct.
   auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
-    auto inputTy = input.getType().cast<RankedTensorType>();
+    auto inputTy = cast<RankedTensorType>(input.getType());
     auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
     SmallVector<int64_t> outputShapeArr = {};
     int32_t i = 0;
@@ -865,7 +865,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
   // Convert the final index to i64 for backend finalization, However, i64
   // is not a defined type for tosa.cast, so using arith.extsi instead.
   auto castToInt64 = [&](Value result) -> LogicalResult {
-    auto resTy = result.getType().cast<ShapedType>();
+    auto resTy = cast<ShapedType>(result.getType());
     if (!resTy)
       return rewriter.notifyMatchFailure(op,
                                          "Argmax: Result is not a shaped type");
@@ -915,7 +915,7 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value self = adaptor.getSelf();
-    auto selfTy = self.getType().template cast<RankedTensorType>();
+    auto selfTy = cast<RankedTensorType>(self.getType());

     if (!selfTy)
       return rewriter.notifyMatchFailure(
@@ -1010,7 +1010,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   Value self = adaptor.getSelf();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());

   if (!selfTy)
     return rewriter.notifyMatchFailure(
@@ -1021,7 +1021,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
         op, "Only floating-point datatype legalization supported");

   auto outType =
-      getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+      cast<TensorType>(getTypeConverter()->convertType(op.getType()));

   Value expTensor;
   Value expScalar = op.getExponent();
@@ -1063,8 +1063,8 @@ public:
                          ConversionPatternRewriter &rewriter, Value &lhs,
                          Value &rhs, Value &output) const {
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
-    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto lhsTy = cast<RankedTensorType>(lhs.getType());
+    auto rhsTy = cast<RankedTensorType>(rhs.getType());

     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
@@ -1097,7 +1097,7 @@ public:
     // construct the input and output reshaping logic.
     auto getRankBroadcastedShape = [&](Value tensor,
                                        bool isRHS) -> SmallVector<int64_t> {
-      auto tensorTy = tensor.getType().cast<TensorType>();
+      auto tensorTy = cast<TensorType>(tensor.getType());
       auto tensorShape = makeShapeTorchCompatible(tensorTy.getShape());
       auto tensorRank = tensorTy.getRank();
@@ -1151,7 +1151,7 @@ public:
     // TOSA matmul is performed on two 3D inputs and generates a 3D output.
     // Lower ranked tensors are dim-1 reshaped up to 3D
     auto reshapeUpTo3DTensor = [&](Value tensor) -> Value {
-      auto tensorTy = tensor.getType().cast<TensorType>();
+      auto tensorTy = cast<TensorType>(tensor.getType());
       auto rank = tensorTy.getRank();

       assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3");
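The comment above describes TOSA's matmul contract: both operands must be rank 3, so lower-rank operands get size-1 dimensions padded in. A standalone sketch of that shape computation, assuming the common convention of prepending the 1s (promoteShapeTo3D is an illustrative helper, not from the patch):

  // [n] -> [1, 1, n]; [m, n] -> [1, m, n]; rank-3 shapes pass through.
  SmallVector<int64_t> promoteShapeTo3D(ArrayRef<int64_t> shape) {
    assert(shape.size() <= 3 && "expected rank <= 3");
    SmallVector<int64_t> result(3 - shape.size(), 1); // leading size-1 dims
    result.append(shape.begin(), shape.end());
    return result;
  }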
@@ -1440,9 +1440,9 @@ public:
       }

       auto matmulLhsShape = makeShapeTorchCompatible(
-          matmulLhs.getType().template cast<RankedTensorType>().getShape());
+          cast<RankedTensorType>(matmulLhs.getType()).getShape());
       auto matmulRhsShape = makeShapeTorchCompatible(
-          matmulRhs.getType().template cast<RankedTensorType>().getShape());
+          cast<RankedTensorType>(matmulRhs.getType()).getShape());

       // The reshape/transpose should ensure the tosa.matmul always has same
       // batch size for either matrix. If if shapes are dynamic, they'll be
@@ -1642,10 +1642,10 @@ public:
                          ConversionPatternRewriter &rewriter,
                          Value &lhs, Value &rhs) const override {
     lhs = adaptor.getSelf();
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+    auto lhsTy = cast<RankedTensorType>(lhs.getType());
     rhs = adaptor.getOther();
-    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsTy = cast<RankedTensorType>(rhs.getType());

     if (!lhsTy || !rhsTy)
       return rewriter.notifyMatchFailure(
@@ -1666,10 +1666,10 @@ public:
                          Value &lhs, Value &rhs) const override {
     lhs = adaptor.getSelf();
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+    auto lhsTy = cast<RankedTensorType>(lhs.getType());
     rhs = adaptor.getMat2();
-    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsTy = cast<RankedTensorType>(rhs.getType());

     if (!lhsTy || !rhsTy)
       return rewriter.notifyMatchFailure(
@@ -1703,10 +1703,10 @@ public:
                          Value &lhs, Value &rhs) const override {
     lhs = adaptor.getInput();
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+    auto lhsTy = cast<RankedTensorType>(lhs.getType());
     rhs = adaptor.getWeight();
-    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsTy = cast<RankedTensorType>(rhs.getType());

     if (!lhsTy || !rhsTy)
       return rewriter.notifyMatchFailure(
@@ -1744,14 +1744,13 @@ public:
     auto biasTy = bias.getType();

     // TOSA does not mandate that elementwise op tensors need to be ranked.
-    if (!biasTy.template isa<Torch::NoneType>() &&
-        !biasTy.template isa<TensorType>())
+    if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
       return rewriter.notifyMatchFailure(
           op, "Only tensor types supported in GEMM to TOSA for bias tensor");

     // RHS must have its last two dims transposed prior to matrix
     // multiplication.
-    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsTy = cast<RankedTensorType>(rhs.getType());
     auto rhsRank = rhsTy.getRank();
     auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());
     auto rhsElemTy = rhsTy.getElementType();
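The bias hunk is one of the few that shrinks the patched region (14 lines to 13), because dropping `template` lets the two negated isa checks fold onto one line. llvm::isa also accepts multiple candidate types, so an equivalent spelling, which this patch does not use, would be:

  // true when biasTy is either Torch::NoneType or TensorType
  if (!isa<Torch::NoneType, TensorType>(biasTy))
    return rewriter.notifyMatchFailure(
        op, "Only tensor types supported in GEMM to TOSA for bias tensor");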
@@ -1789,7 +1788,7 @@ public:
                                          "Failed to perform matmul operation");

     Value matmulPlusBias = matmulOutput;
-    if (!biasTy.template isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(biasTy)) {
       // Bias addition broadcasts to the matmul output shape.
       matmulPlusBias =
           rewriter
@@ -1818,7 +1817,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
   auto otherScalar = op.getOther();
   auto alphaScalar = op.getAlpha();

-  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());
   if (!selfTy)
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types supported in TOSA Rsub");
@@ -1867,8 +1866,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();

-  auto inputTy = input.getType().cast<RankedTensorType>();
-  auto weightTy = weight.getType().cast<RankedTensorType>();
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto weightTy = cast<RankedTensorType>(weight.getType());
   auto outputTy = getTypeConverter()
                       ->convertType(op.getType())
                       .template cast<RankedTensorType>();
@@ -1893,7 +1892,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   // Bias is optional. TOSA mandates a zero tensor here, so construct one if
   // required.
   auto bias = adaptor.getBias();
-  if (adaptor.getBias().getType().template isa<Torch::NoneType>()) {
+  if (isa<Torch::NoneType>(adaptor.getBias().getType())) {
     // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
     // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
     // define a 48-bit int.
@@ -1909,7 +1908,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             .value();
     }
   } else {
-    if (!bias.getType().cast<RankedTensorType>())
+    if (!cast<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
           op, "Bias provided but not a ranked tensor");
   }
@@ -2115,7 +2114,7 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(

   auto self = adaptor.getSelf();
-  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());
   if (!selfTy)
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types supported in TOSA Reshape");
@@ -2199,7 +2198,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a ranked tensor output
-  if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
+  if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types are supported");
@@ -2211,8 +2210,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   if (op.getMomentum().getType().isa<Torch::NoneType>())
     return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");

-  auto meanType = adaptor.getRunningMean().getType().dyn_cast<TensorType>();
-  auto varianceType = adaptor.getRunningVar().getType().dyn_cast<TensorType>();
+  auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
+  auto varianceType = dyn_cast<TensorType>(adaptor.getRunningVar().getType());
   if (!varianceType || !meanType)
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types are supported");
@@ -2225,7 +2224,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
                           const TypeConverter *converter, Type outType,
                           const Value toBcast, Value &result) {
     RankedTensorType toBcastType =
-        toBcast.getType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(toBcast.getType());
     if (toBcastType.getRank() > 1)
       return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
@@ -2298,11 +2297,11 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   // eventually being reshaped for broadcasting.

   // Not a ranked tensor output
-  if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
+  if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types are supported");

-  auto inputType = adaptor.getInput().getType().cast<RankedTensorType>();
+  auto inputType = cast<RankedTensorType>(adaptor.getInput().getType());
   if (inputType.getRank() > 4)
     return rewriter.notifyMatchFailure(op,
                                        "Only up to 4D tensors are supported");
@@ -2317,8 +2316,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   if (adaptor.getBias().getType().isa<Torch::NoneType>())
     return rewriter.notifyMatchFailure(op, "Unsupported None for bias");

-  auto weightType = adaptor.getWeight().getType().cast<RankedTensorType>();
-  auto biasType = adaptor.getBias().getType().cast<RankedTensorType>();
+  auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
+  auto biasType = cast<RankedTensorType>(adaptor.getBias().getType());
   int64_t inputRank = inputType.getRank();
   Type elemTy = inputType.getElementType();
   SmallVector<int64_t> inputTypeShape(
@@ -2461,7 +2460,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
   // element type. All tensors with element types other than integer can reuse
   // existing elements attribute.
   // TODO: what about unsigned integer?
-  if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
+  if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
     if (elements.getElementType().isSignedInteger()) {
       Type builtinTensorElemTy = outputTy.getElementType();
       unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
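This hunk shows the free functions applying to Attributes exactly as to Types, and the if-init form keeps the probe-and-use idiom intact: the body runs only when the downcast succeeds, and `elements` is scoped to the branch. A condensed sketch of the idiom (attr and the process* handlers are illustrative):

  if (auto ints = dyn_cast<DenseIntElementsAttr>(attr))
    processInts(ints);
  else if (auto floats = dyn_cast<DenseFPElementsAttr>(attr))
    processFloats(floats);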
@@ -2483,7 +2482,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a ranked tensor type
-  auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
+  auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(op,
                                        "Only ranked tensor types supported");
@@ -2548,7 +2547,7 @@ LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a ranked tensor type
-  auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
+  auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op,
@@ -2602,7 +2601,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a ranked tensor type
-  auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
+  auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op,
@@ -2637,7 +2636,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2665,7 +2664,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2715,7 +2714,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType) {
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2763,7 +2762,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2781,7 +2780,7 @@ LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getInput().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getInput().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2807,7 +2806,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2869,7 +2868,7 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
   //
   // Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
-  auto outType = x.getType().cast<TensorType>();
+  auto outType = cast<TensorType>(x.getType());
   auto loc = op->getLoc();
   auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
   auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
@@ -2949,7 +2948,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -2986,7 +2985,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -3043,7 +3042,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType) {
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -3063,7 +3062,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
   }

   Value gradOutput = adaptor.getGradOutput();
-  auto gradOutputType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto gradOutputType = dyn_cast<TensorType>(adaptor.getSelf().getType());

   Type gradOutputElemType = gradOutputType.getElementType();
@@ -3119,14 +3118,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
   Value weight = adaptor.getWeight();
   Value indices = adaptor.getIndices();
   RankedTensorType outType =
-      typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+      cast<RankedTensorType>(typeConverter->convertType(op.getType()));

-  auto indicesType = indices.getType().dyn_cast<RankedTensorType>();
+  auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
   if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
     return rewriter.notifyMatchFailure(
         op, "Indices must be of integer tensor type");

-  auto weightType = weight.getType().cast<RankedTensorType>();
+  auto weightType = cast<RankedTensorType>(weight.getType());
   if (weightType.getRank() != 2)
     return op.emitError("weight must be of rank 2");
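Note that context lines in several hunks (the `.isa<IntegerType>()` just above, `op.getMomentum().getType().isa<...>()` in the batch-norm pattern, the three-line `.template cast<RankedTensorType>()` chain in the convolution pattern) still carry the member form: the patch only rewrites lines it touches. Under the same rule those would eventually read:

  bool indicesAreInt = isa<IntegerType>(indicesType.getElementType());
  auto outputTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));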
@@ -3216,7 +3215,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
     AtenTransposeIntOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {

-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
@@ -3258,12 +3257,12 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
     AtenMaxDimOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {

-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

   auto indicesType =
-      getTypeConverter()->convertType(op.getType(1)).dyn_cast<TensorType>();
+      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
   if (!indicesType)
     return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
@@ -3334,7 +3333,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     AtenSliceTensorOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {

-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
@@ -3406,7 +3405,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
@@ -3500,13 +3499,13 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
   // Not a tensor type.
   auto input = adaptor.getSelf();
-  auto inputType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
+  auto inputType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
   if (!inputType)
     return rewriter.notifyMatchFailure(
         op, "Only RankedTensorType input are currently supported");

   auto index = adaptor.getIndex();
-  auto indexType = adaptor.getIndex().getType().dyn_cast<RankedTensorType>();
+  auto indexType = dyn_cast<RankedTensorType>(adaptor.getIndex().getType());
   auto inputShape = inputType.getShape();
   int paramsRank = inputShape.size();
@@ -3593,13 +3592,13 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
   // Not a tensor type.
   auto input = adaptor.getSelf();
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");

   auto fillValues = adaptor.getValues();
-  auto valuesType = adaptor.getValues().getType().dyn_cast<TensorType>();
+  auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
   if (!valuesType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");
@@ -3640,7 +3639,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
           return rewriter.notifyMatchFailure(
               op, "Multiple None index is not support for now.");
         }
-        auto indexNextType = indexNext.getType().dyn_cast<RankedTensorType>();
+        auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
         auto indexNextShape = indexNextType.getShape();

         int64_t size = 1;
@@ -3652,7 +3651,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
                     .value();
       }

-      auto indexType = index.getType().dyn_cast<RankedTensorType>();
+      auto indexType = dyn_cast<RankedTensorType>(index.getType());
       auto indexShape = indexType.getShape();
       indexesShape.push_back(makeShapeTorchCompatible(indexShape));
       indexesRank.push_back(indexType.getRank());
@@ -3734,7 +3733,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   // [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
   auto input = adaptor.getSelf();
   auto inputTensorType =
-      adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
   // Check input is a tensor type.
   if (!inputTensorType)
     return rewriter.notifyMatchFailure(
@@ -3771,7 +3770,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     for (size_t i = 0; i < indexTensors.size(); i++) {
       auto index = indexTensors[i];
-      auto indexType = index.getType().dyn_cast<RankedTensorType>();
+      auto indexType = dyn_cast<RankedTensorType>(index.getType());
       auto indexShape = indexType.getShape();
       indexesShape.push_back(makeShapeTorchCompatible(indexShape));
       indexesRank.push_back(indexType.getRank());
@@ -3837,7 +3836,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   // Support for multiple index
   auto index = indexTensors[0];
-  auto indexType = index.getType().dyn_cast<RankedTensorType>();
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
   auto indexShape = indexType.getShape();
   // index i64 to i32 for tosa compatible
   if (indexType.getElementType() != rewriter.getIntegerType(32)) {
@@ -3879,7 +3878,7 @@ LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
     AtenAbsOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");
@@ -3896,11 +3895,11 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");

-  auto condType = adaptor.getCondition().getType().dyn_cast<TensorType>();
+  auto condType = dyn_cast<TensorType>(adaptor.getCondition().getType());
   if (!condType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types condition are currently supported");
@@ -3919,11 +3918,11 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");

-  auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
+  auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
   if (!otherType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types condition are currently supported");
@@ -3955,8 +3954,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
         op, "unimplemented: equal_nan is expected to be false");

   // check tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
-  auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
   if (!selfType || !otherType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");
@@ -3998,7 +3997,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "only tensor types input are currently supported");
@@ -4251,8 +4250,8 @@ LogicalResult ConvertAtenOp<AtenCopyOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
-  auto srcType = adaptor.getSrc().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto srcType = dyn_cast<TensorType>(adaptor.getSrc().getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
@@ -4297,7 +4296,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
@@ -4355,14 +4354,14 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {

   Value self = adaptor.getSelf();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());

   if (!selfTy)
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types supported in TOSA Remainder");

   auto outType =
-      getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+      cast<TensorType>(getTypeConverter()->convertType(op.getType()));

   Type outElemTy = outType.getElementType();
   if (!outElemTy.isIntOrFloat())
@@ -4438,7 +4437,7 @@ public:
   // Apply the transposeDims vector on input to generate a transposed form.
   Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
                         Value input, ArrayRef<int32_t> transposeDims) const {
-    auto inputTy = input.getType().template cast<RankedTensorType>();
+    auto inputTy = cast<RankedTensorType>(input.getType());
     auto inputElemTy = inputTy.getElementType();
     auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
     auto inputRank = inputTy.getRank();
@@ -4462,8 +4461,7 @@ public:
   Value transposePoolingInputToHwc(AtenOpT op,
                                    ConversionPatternRewriter &rewriter,
                                    Value input) const {
-    auto inputRank =
-        input.getType().template cast<RankedTensorType>().getRank();
+    auto inputRank = cast<RankedTensorType>(input.getType()).getRank();

     SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
     SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
@@ -4476,7 +4474,7 @@ public:
   Value transposePoolingOutputToChw(AtenOpT op,
                                     ConversionPatternRewriter &rewriter,
                                     Value input) const {
-    auto inputTy = input.getType().template cast<RankedTensorType>();
+    auto inputTy = cast<RankedTensorType>(input.getType());
     auto inputRank = inputTy.getRank();

     SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
@@ -4547,7 +4545,7 @@ public:
                            DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                            Type &outputTy) const override {
     auto inputXchw = adaptor.getSelf();
-    auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
+    auto inputTy = cast<RankedTensorType>(inputXchw.getType());
     if (!inputTy)
       return rewriter.notifyMatchFailure(
           op, "Adaptive avgpool requires ranked tensor input");
@@ -4659,7 +4657,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {

   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
-  RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
+  RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
     return rewriter.notifyMatchFailure(
         op, "Pooling op requires ranked tensor input");
@@ -4797,7 +4795,7 @@ public:
     // FIXME: Handle layout, device and pin_memory. Assume dtype has been
     // processed to set output type correctly?
     // The layout arg should be either `none` or `0` i.e. strided.
-    if (!op.getLayout().getType().template isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getLayout().getType())) {
       int64_t tensorLayout;
       if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
         return rewriter.notifyMatchFailure(
@@ -4808,7 +4806,7 @@ public:
     }

     bool pinMemory;
-    if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
+    if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
         (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
          pinMemory)) {
       return rewriter.notifyMatchFailure(
@@ -4892,19 +4890,19 @@ public:
     }

     // Not a tensor type.
-    auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
+    auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
     if (!selfType || !outType.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op,
           "Only tensor types with static shapes input are currently supported");

-    auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
+    auto maskType = dyn_cast<TensorType>(adaptor.getMask().getType());
     if (!maskType)
       return rewriter.notifyMatchFailure(
           op, "Only tensor types mask are currently supported");

     Value rhs = adaptor.getValue();
-    auto rhsType = rhs.getType().template dyn_cast<TensorType>();
+    auto rhsType = dyn_cast<TensorType>(rhs.getType());
     Value rhsAsTensor;
     if (!rhsType) { // scalar
       if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
@@ -4913,11 +4911,11 @@ public:
             op, "Currently only scalar constants are supported for "
                 "conversion in TOSA operation");
     } else { // tensor
-      rhsType = rhs.getType().dyn_cast<TensorType>();
+      rhsType = dyn_cast<TensorType>(rhs.getType());
     }

     auto rhsTensor = rhsType ? rhs : rhsAsTensor;
-    auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
+    auto rhsTensorType = dyn_cast<TensorType>(rhsTensor.getType());
     if (rhsTensorType.getElementType() != outElemTy)
       rhsTensor = rewriter.create<tosa::CastOp>(
           op.getLoc(),
@@ -4940,7 +4938,7 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int64_t memoryFormat;
-    if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()) &&
         (!matchPattern(op.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
          (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
@@ -4964,7 +4962,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value self = adaptor.getSelf();
-  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());
   auto selfElemTy = selfTy.getElementType();
   int64_t rank = selfTy.getRank();
@@ -5033,7 +5031,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   const TypeConverter *typeConverter = this->getTypeConverter();
   auto outType =
-      typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+      cast<RankedTensorType>(typeConverter->convertType(op.getType()));
   int64_t rank = outType.getRank();

   int64_t dim;
@@ -5074,7 +5072,7 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
   // Converts AtenSqrtOp into pow(x, 0.5)
   auto self = adaptor.getSelf();
-  auto selfTy = self.getType().dyn_cast<TensorType>();
+  auto selfTy = dyn_cast<TensorType>(self.getType());
   if (!selfTy)
     return rewriter.notifyMatchFailure(op,
                                        "Only Tensor types supported in TOSA");


@@ -117,8 +117,8 @@ template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
                                         Operation *op, TensorType outType,
                                         Value lhs, Value rhs) {
-  auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
-  auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
+  auto lhsElemTy = cast<TensorType>(lhs.getType()).getElementType();
+  auto rhsElemTy = cast<TensorType>(rhs.getType()).getElementType();
  if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
    (void)rewriter.notifyMatchFailure(op,
                                      "tosa.div only supports integer type");
@@ -148,8 +148,8 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
  //  [2,1]              [[0, 3, 2],[0, 3, 1]]
  // ]] 1*4*2            ]] 1*4*2*3
-  auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
-  auto indexType = indexValue.getType().dyn_cast<RankedTensorType>();
+  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
+  auto indexType = dyn_cast<RankedTensorType>(indexValue.getType());
  auto paramsShape = paramsType.getShape(); // [1 4 3]
  auto indexShape = indexType.getShape();   // [1 4 2]
  int paramsRank = paramsShape.size();      // 3
@@ -214,8 +214,8 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
                                       Type outType, Value paramsValue,
                                       Value indicesValue) {
  auto resultType = dyn_cast<ShapedType>(outType);
-  auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
-  auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
+  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
+  auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
  if (!resultType || !paramsType || !indicesType)
    return std::nullopt;
@@ -420,9 +420,9 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
                                        Value paramsValue, Value indicesValue,
                                        Value fillValues) {
  auto resultType = dyn_cast<ShapedType>(outType);
-  auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
-  auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
-  auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
+  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
+  auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
+  auto fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
  if (!resultType || !paramsType || !indicesType)
    return std::nullopt;
@@ -572,7 +572,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
        tosaFillValuesTileOp.getResult(),
        rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
    fillValues = newTosaFillValuesReshapeOp.getResult();
-    fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
+    fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
  }

  // fillK: range of each index, total number of fillInput(could be scatter)
@@ -691,7 +691,7 @@ std::optional<Value> convertReduceOpCommon(
    Type reduce_element_type, bool is_quantized, double input_scale,
    int64_t input_zp, double output_scale, int64_t output_zp) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
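These legalize-common helpers all share one guard pattern that the migration leaves intact: probe with dyn_cast, bail out with std::nullopt when the type requirement is not met, so callers treat a failed probe as "pattern does not apply". A self-contained sketch (lowerIfRanked is a hypothetical helper):

  std::optional<Value> lowerIfRanked(Value v) {
    auto ty = dyn_cast<RankedTensorType>(v.getType());
    if (!ty)
      return std::nullopt; // not ranked: signal "no lowering produced"
    return v;
  }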
@@ -754,7 +754,7 @@ convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -769,7 +769,7 @@ convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -784,7 +784,7 @@ convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -799,7 +799,7 @@ convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -814,7 +814,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -840,7 +840,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -863,9 +863,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
  if (input_is_qtype) {
    auto input_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+        cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
    auto output_qtype =
-        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+        cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());

    int32_t input_shift = 20;
@@ -895,7 +895,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
  //   op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)

  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;
@@ -940,9 +940,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
  if (input_is_qtype) {
    auto input_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+        cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
    auto output_qtype =
-        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+        cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());

    // Combine 'div_scale' as part of output rescale
    output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
@@ -976,7 +976,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                          RankedTensorType output_type, Value input_value,
                          ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
-      input_value.getType().dyn_cast<RankedTensorType>();
+      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;


@@ -45,7 +45,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type
-  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
+  auto input_type = dyn_cast<mlir::ShapedType>(input_val.getType());
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());
@@ -58,9 +58,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
                               Value conv_val, ShapedType input_type,
                               ShapedType weight_type, ShapedType output_type) {
  auto input_qtype =
-      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
-  auto output_qtype = output_type.getElementType()
-                          .dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  auto output_qtype =
+      dyn_cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());

  double input_scale = input_qtype.getScale();
@@ -71,8 +71,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
  int32_t scale_width = scale32 ? 32 : 16;

  if (auto weight_per_tensor_qtype =
-          weight_type.getElementType()
-              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
+          dyn_cast<mlir::quant::UniformQuantizedType>(
+              weight_type.getElementType())) {
    // Per-tensor quantization
    double weight_scale = weight_per_tensor_qtype.getScale();
@@ -94,8 +94,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
    return rescale_op.getResult();
  } else if (auto weight_per_channel_qtype =
-                 weight_type.getElementType()
-                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
+                 dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(
+                     weight_type.getElementType())) {
    // Per-channel quantization
    SmallVector<int32_t> multiplier_arr;
    SmallVector<int8_t> shift_arr;
@ -311,7 +311,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) { Value src, Type destType, Value &result) {
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType(); Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType(); Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
if (failed(checkValidityOfCast(srcElemTy, destElemTy))) if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
@ -319,7 +319,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
op, "casting to result dtype is invalid or unsupported"); op, "casting to result dtype is invalid or unsupported");
if (destElemTy.isInteger(1)) { if (destElemTy.isInteger(1)) {
auto srcType = src.getType().dyn_cast<TensorType>(); auto srcType = dyn_cast<TensorType>(src.getType());
SmallVector<int64_t> srcShape(srcType.getShape()); SmallVector<int64_t> srcShape(srcType.getShape());
uint64_t num_total_elements = 1; uint64_t num_total_elements = 1;
for (int64_t a : srcShape) for (int64_t a : srcShape)
@ -355,7 +355,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
Operation *op = input.getDefiningOp(); Operation *op = input.getDefiningOp();
TensorType inType = input.getType().cast<TensorType>(); TensorType inType = cast<TensorType>(input.getType());
if (inType.getElementType() != outType.getElementType()) { if (inType.getElementType() != outType.getElementType()) {
TensorType promotedType = TensorType promotedType =
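
One distinction these utility hunks preserve: `cast<T>(x)` asserts on mismatch and suits call sites where the type is already guaranteed (e.g. `promoteType` above), while `dyn_cast<T>(x)` returns null and suits call sites that check the result (e.g. the `assert(input_type)` in `buildRescaleToInt32`). A hedged sketch of both shapes; the helper names are illustrative, not from this commit:

#include "mlir/IR/BuiltinTypes.h"

// cast<T> asserts on mismatch; use it only when the type is guaranteed.
static mlir::Type knownTensorElementType(mlir::Value input) {
  return mlir::cast<mlir::TensorType>(input.getType()).getElementType();
}

// dyn_cast<T> returns null on mismatch and must be checked.
static mlir::Type tensorElementTypeOrNull(mlir::Value input) {
  if (auto shaped = mlir::dyn_cast<mlir::ShapedType>(input.getType()))
    return shaped.getElementType();
  return mlir::Type();
}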


@ -52,7 +52,7 @@ LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank // Generate IR: dim = dim >= 0 ? dim : dim + inputRank
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
Value inputRank) { Value inputRank) {
assert(dim.getType().isa<IntegerType>() && assert(isa<IntegerType>(dim.getType()) &&
"dim arg of toPositiveDim must be integer type"); "dim arg of toPositiveDim must be integer type");
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank); Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
Value cst0 = Value cst0 =
@ -132,7 +132,7 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) { Type elemTy) {
Value initTensor = Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy); b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
RankedTensorType type = initTensor.getType().cast<RankedTensorType>(); RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
Value c0 = Value c0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType())); b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0); return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
@ -172,7 +172,7 @@ Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc, SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
Value tensor, int dim) { Value tensor, int dim) {
RankedTensorType type = tensor.getType().cast<RankedTensorType>(); RankedTensorType type = cast<RankedTensorType>(tensor.getType());
assert(dim < type.getRank() && assert(dim < type.getRank() &&
"The given dim must be smaller than tensor rank"); "The given dim must be smaller than tensor rank");
(void)type; (void)type;
@ -183,7 +183,7 @@ SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
} }
SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) { SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
RankedTensorType type = tensor.getType().cast<RankedTensorType>(); RankedTensorType type = cast<RankedTensorType>(tensor.getType());
return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1); return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
} }
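
`isa` migrates the same way: `dim.getType().isa<IntegerType>()` becomes `isa<IntegerType>(dim.getType())`, which reads naturally inside asserts and boolean predicates like the ones above. The free form also has a variadic spelling that checks several candidates at once. An illustrative predicate, not commit code:

#include "mlir/IR/BuiltinTypes.h"

// True for any builtin integer or float type; isa<A, B>(x) is shorthand
// for isa<A>(x) || isa<B>(x).
static bool isScalarLike(mlir::Type ty) {
  return mlir::isa<mlir::IntegerType, mlir::FloatType>(ty);
}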


@ -77,7 +77,7 @@ Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v, OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
int64_t dim) { int64_t dim) {
auto t = v.getType().cast<ShapedType>(); auto t = cast<ShapedType>(v.getType());
if (t.isDynamicDim(dim)) { if (t.isDynamicDim(dim)) {
return getDimValue(builder, loc, v, dim); return getDimValue(builder, loc, v, dim);
} }
@ -123,7 +123,7 @@ bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes, static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
Value rhs, ValueRange rhsSizes, Value output, Value rhs, ValueRange rhsSizes, Value output,
ValueRange outputSizes, bool transposed = false) { ValueRange outputSizes, bool transposed = false) {
auto elementType = lhs.getType().cast<MemRefType>().getElementType(); auto elementType = cast<MemRefType>(lhs.getType()).getElementType();
Value one = b.create<arith::ConstantIndexOp>(loc, 1); Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0); Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
auto rank = outputSizes.size(); auto rank = outputSizes.size();
@ -168,9 +168,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value key = getKey(); Value key = getKey();
Value value = getValue(); Value value = getValue();
Value output = getOutput(); Value output = getOutput();
auto queryType = query.getType().cast<MemRefType>(); auto queryType = cast<MemRefType>(query.getType());
auto keyType = key.getType().cast<MemRefType>(); auto keyType = cast<MemRefType>(key.getType());
auto valueType = value.getType().cast<MemRefType>(); auto valueType = cast<MemRefType>(value.getType());
auto queryRank = queryType.getRank(); auto queryRank = queryType.getRank();
auto keyRank = keyType.getRank(); auto keyRank = keyType.getRank();
auto valueRank = valueType.getRank(); auto valueRank = valueType.getRank();
@ -330,12 +330,12 @@ LogicalResult ScanOp::verify() {
if (getNumOutputs() != 2) { if (getNumOutputs() != 2) {
return emitOpError("expected two output operands"); return emitOpError("expected two output operands");
} }
if (!input().getType().isa<ShapedType>()) { if (!isa<ShapedType>(input().getType())) {
return emitOpError("expected first input element type to be shaped"); return emitOpError("expected first input element type to be shaped");
} }
auto accumulatorType = accumulator().getType().cast<ShapedType>(); auto accumulatorType = cast<ShapedType>(accumulator().getType());
auto inputType = input().getType().cast<ShapedType>(); auto inputType = cast<ShapedType>(input().getType());
auto outputType = output().getType().cast<ShapedType>(); auto outputType = cast<ShapedType>(output().getType());
ArrayRef<int64_t> inputShapes = inputType.getShape(); ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape(); ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) { if (accumulatorType.getElementType() != inputType.getElementType()) {
@ -706,7 +706,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
loadIndices.push_back(Value()); loadIndices.push_back(Value());
// Populate with empty values. // Populate with empty values.
auto originalTy = original().getType().cast<ShapedType>(); auto originalTy = cast<ShapedType>(original().getType());
starts.resize(originalTy.getRank(), Value()); starts.resize(originalTy.getRank(), Value());
auto updateIvs = ivs.drop_front(1); auto updateIvs = ivs.drop_front(1);
@ -797,7 +797,7 @@ LogicalResult SortOp::verify() {
if (yieldOp.getNumOperands() != 1) { if (yieldOp.getNumOperands() != 1) {
return op->emitOpError("should yield exactly one operand"); return op->emitOpError("should yield exactly one operand");
} }
auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>(); auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
if (!ty || ty.getWidth() != 1) { if (!ty || ty.getWidth() != 1) {
return op->emitOpError("should yield i1 type"); return op->emitOpError("should yield i1 type");
} }
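
The `SortOp::verify` hunk shows the dyn_cast-then-inspect shape inside a verifier: cast, test for null, then check a property before emitting a diagnostic. A hedged sketch of that shape; the function is illustrative, only the i1 check mirrors the hunk:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

// Fails with a diagnostic unless `yielded` has the 1-bit integer type.
static mlir::LogicalResult verifyYieldsI1(mlir::Operation *op,
                                          mlir::Value yielded) {
  auto intTy = mlir::dyn_cast<mlir::IntegerType>(yielded.getType());
  if (!intTy || intTy.getWidth() != 1)
    return op->emitOpError("should yield i1 type");
  return mlir::success();
}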


@ -29,7 +29,7 @@ using namespace ::mlir;
using namespace ::mlir::torch::TMTensor; using namespace ::mlir::torch::TMTensor;
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
auto memrefType = memref.getType().cast<MemRefType>(); auto memrefType = cast<MemRefType>(memref.getType());
auto alloc = b.create<memref::AllocOp>( auto alloc = b.create<memref::AllocOp>(
loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType()); loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType());
b.create<memref::CopyOp>(loc, memref, alloc); b.create<memref::CopyOp>(loc, memref, alloc);


@ -80,7 +80,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
return failure(); return failure();
} }
if (llvm::any_of(scalarLoopOp->getResults(), if (llvm::any_of(scalarLoopOp->getResults(),
[&](Value v) { return v.getType().isa<ShapedType>(); })) { [&](Value v) { return isa<ShapedType>(v.getType()); })) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
scalarLoopOp, "lower to loops needs to have tensor semantics"); scalarLoopOp, "lower to loops needs to have tensor semantics");
} }


@ -122,14 +122,14 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
auto func = dyn_cast<func::FuncOp>(op); auto func = dyn_cast<func::FuncOp>(op);
if (!func) if (!func)
return op->emitError() << "'torch.type_bound' must be attached to a func"; return op->emitError() << "'torch.type_bound' must be attached to a func";
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>(); TypeAttr attr = dyn_cast<TypeAttr>(namedAttr.getValue());
if (!attr) if (!attr)
return op->emitError() << "'torch.type_bound' must be TypeAttr"; return op->emitError() << "'torch.type_bound' must be TypeAttr";
auto type = attr.getValue().dyn_cast<BaseTensorType>(); auto type = dyn_cast<BaseTensorType>(attr.getValue());
if (!type) if (!type)
return op->emitError() << "'torch.type_bound' must be of " return op->emitError() << "'torch.type_bound' must be of "
"!torch.tensor/!torch.vtensor type"; "!torch.tensor/!torch.vtensor type";
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>()) if (!isa<BaseTensorType>(func.getFunctionType().getInput(argIndex)))
return op->emitError() << "'torch.type_bound' must be attached to an " return op->emitError() << "'torch.type_bound' must be attached to an "
"argument of !torch.tensor/!torch.vtensor type"; "argument of !torch.tensor/!torch.vtensor type";
return success(); return success();
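
Attributes follow the same recipe as types: `namedAttr.getValue().dyn_cast<TypeAttr>()` becomes `dyn_cast<TypeAttr>(namedAttr.getValue())`, and the unwrapped payload can then be re-checked with a second free-function cast, as the chain above does. A minimal sketch; `typeBoundOrNull` is an illustrative name:

#include "mlir/IR/BuiltinAttributes.h"

// Unwrap a TypeAttr, yielding a null Type when `attr` is something else.
static mlir::Type typeBoundOrNull(mlir::Attribute attr) {
  auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr);
  return typeAttr ? typeAttr.getValue() : mlir::Type();
}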


@ -75,7 +75,7 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
BaseTensorType newType, BaseTensorType newType,
Value tensor) { Value tensor) {
auto originalType = tensor.getType().cast<BaseTensorType>(); auto originalType = cast<BaseTensorType>(tensor.getType());
// Adjust the static information in the type to match between the original and // Adjust the static information in the type to match between the original and
// new types. // new types.
if (!originalType.hasSameSizesAndDtype(newType)) { if (!originalType.hasSameSizesAndDtype(newType)) {
@ -87,7 +87,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
// up creating one op that converts between the value and non-value tensor // up creating one op that converts between the value and non-value tensor
// domains. If both the original and new types are non-value tensors, // domains. If both the original and new types are non-value tensors,
// then we do the copy by going to a value tensor and back. // then we do the copy by going to a value tensor and back.
if (tensor.getType().isa<NonValueTensorType>()) if (isa<NonValueTensorType>(tensor.getType()))
tensor = builder.create<CopyToValueTensorOp>(loc, tensor); tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
if (isa<NonValueTensorType>(newType)) if (isa<NonValueTensorType>(newType))
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor); tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
@ -96,7 +96,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
} }
bool mlir::torch::Torch::isListPotentiallyMutated(Value list) { bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
assert(list.getType().isa<Torch::ListType>()); assert(isa<Torch::ListType>(list.getType()));
return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands); return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
} }
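
In `copyTensorToType` the free `isa<T>(x)` drives control flow between the value and non-value tensor domains rather than guarding an assert. A sketch of the same branching shape, with builtin types standing in for the torch-mlir tensor types (illustrative only):

#include "mlir/IR/BuiltinTypes.h"

// Dispatch on the dynamic kind of a type, one isa per branch.
static const char *describe(mlir::Type ty) {
  if (mlir::isa<mlir::FloatType>(ty))
    return "float";
  if (mlir::isa<mlir::IntegerType>(ty))
    return "integer";
  return "other";
}
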
@ -148,8 +148,7 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr; return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) { if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue() auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>(); .getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>( return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val)); loc, rewriter.getI64IntegerAttr(val));
@ -777,7 +776,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (getOperand(0).getType() != getResult().getType()) if (getOperand(0).getType() != getResult().getType())
return nullptr; return nullptr;
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) { if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0); return getOperand(0);
} }
@ -798,11 +797,11 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg) if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
return nullptr; return nullptr;
// The memory_format arg must be `none`. // The memory_format arg must be `none`.
if (!getMemoryFormat().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
return nullptr; return nullptr;
auto inputType = getSelf().getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(getSelf().getType());
auto resType = getType().cast<BaseTensorType>(); auto resType = cast<BaseTensorType>(getType());
// If the types aren't equal, then we can't fold. // If the types aren't equal, then we can't fold.
if (inputType != resType) if (inputType != resType)
return nullptr; return nullptr;
@ -821,7 +820,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
// The pin_memory arg should be either constant `False` or `none`. // The pin_memory arg should be either constant `False` or `none`.
if (!getPinMemory().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(getPinMemory().getType())) {
bool pinMemory; bool pinMemory;
if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory))) if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
return nullptr; return nullptr;
@ -844,15 +843,15 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
// The device arg must be `none`. // The device arg must be `none`.
if (!getDevice().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(getDevice().getType()))
return nullptr; return nullptr;
// The memory_format arg must be `none`. // The memory_format arg must be `none`.
if (!getMemoryFormat().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
return nullptr; return nullptr;
auto inputType = getSelf().getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(getSelf().getType());
auto resType = getType().cast<BaseTensorType>(); auto resType = cast<BaseTensorType>(getType());
// If the types aren't equal, then we can't fold. // If the types aren't equal, then we can't fold.
if (inputType != resType) if (inputType != resType)
return nullptr; return nullptr;
@ -863,7 +862,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
// The layout arg should be either `none` or `0` i.e. strided. // The layout arg should be either `none` or `0` i.e. strided.
if (!getLayout().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(getLayout().getType())) {
int64_t tensorLayout; int64_t tensorLayout;
if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout))) if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
return nullptr; return nullptr;
@ -882,7 +881,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
// is false // is false
patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) { patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
// The pin_memory arg should be either constant `False` or `none`. // The pin_memory arg should be either constant `False` or `none`.
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
bool pinMemory; bool pinMemory;
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
return failure(); return failure();
@ -891,7 +890,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
} }
// The layout arg should be either `none` or `0` i.e. strided. // The layout arg should be either `none` or `0` i.e. strided.
if (!op.getLayout().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout; int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return failure(); return failure();
@ -899,7 +898,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
return failure(); return failure();
} }
if (op.getDevice().getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getDevice().getType())) {
// The device arg is `none`. Rewrite to to.dtype. // The device arg is `none`. Rewrite to to.dtype.
AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>( AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
@ -985,10 +984,10 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>(); auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
return nullptr; return nullptr;
auto resType = getType().dyn_cast<BaseTensorType>(); auto resType = dyn_cast<BaseTensorType>(getType());
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
return nullptr; return nullptr;
if (inputType != resType) if (inputType != resType)
@ -1011,7 +1010,7 @@ OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) { if (auto tensorType = dyn_cast<BaseTensorType>(getOperand().getType())) {
if (tensorType.hasSizes()) if (tensorType.hasSizes())
return IntegerAttr::get(IntegerType::get(getContext(), 64), return IntegerAttr::get(IntegerType::get(getContext(), 64),
tensorType.getSizes().size()); tensorType.getSizes().size());
@ -1117,7 +1116,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
} }
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) { if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
if (op->getOperand(2).getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op->getOperand(2).getType())) {
// None rounding mode // None rounding mode
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs); Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
@ -1879,9 +1878,9 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
auto resultType = getType().dyn_cast<ValueTensorType>(); auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() && if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) { isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf(); return getSelf();
} }
return {}; return {};
@ -1892,9 +1891,9 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
auto resultType = getType().dyn_cast<ValueTensorType>(); auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() && if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) { isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf(); return getSelf();
} }
return {}; return {};
@ -1905,9 +1904,9 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
auto resultType = getType().dyn_cast<ValueTensorType>(); auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() && if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) { isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf(); return getSelf();
} }
return {}; return {};
@ -1918,7 +1917,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
auto resultType = getType().dyn_cast<ValueTensorType>(); auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() && if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) { resultType.getDtype().isa<mlir::IntegerType>()) {
return getSelf(); return getSelf();
@ -1987,7 +1986,7 @@ void AtenDivScalarModeOp::getCanonicalizationPatterns(
void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns, void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { MLIRContext *context) {
patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) { patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>(); auto inputType = dyn_cast<BaseTensorType>(op.getSelf().getType());
if (!inputType || !inputType.areAllSizesKnown()) { if (!inputType || !inputType.areAllSizesKnown()) {
return failure(); return failure();
} }
@ -2113,7 +2112,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
if (!value || !value.getType().isa<BaseTensorType>()) if (!value || !value.getType().isa<BaseTensorType>())
return failure(); return failure();
auto tensorType = value.getType().cast<BaseTensorType>(); auto tensorType = cast<BaseTensorType>(value.getType());
if (foundType(tensorType, dim)) if (foundType(tensorType, dim))
return tensorType; return tensorType;
@ -2649,7 +2648,7 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
.dyn_cast_or_null<ElementsAttr>(); .dyn_cast_or_null<ElementsAttr>();
if (!attr) if (!attr)
return failure(); return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>(); RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
NonValueTensorType returnType = NonValueTensorType returnType =
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(), NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
tensorType.getElementType()); tensorType.getElementType());
@ -2691,7 +2690,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
.dyn_cast_or_null<ElementsAttr>(); .dyn_cast_or_null<ElementsAttr>();
if (!attr) if (!attr)
return failure(); return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>(); RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
ValueTensorType returnType = ValueTensorType returnType =
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(), ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
tensorType.getElementType()); tensorType.getElementType());
@ -2751,8 +2750,8 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult CopyToNonValueTensorOp::verify() { LogicalResult CopyToNonValueTensorOp::verify() {
auto resultType = getResult().getType().cast<BaseTensorType>(); auto resultType = cast<BaseTensorType>(getResult().getType());
auto operandType = getOperand().getType().cast<BaseTensorType>(); auto operandType = cast<BaseTensorType>(getOperand().getType());
if (!resultType.hasSameSizesAndDtype(operandType)) if (!resultType.hasSameSizesAndDtype(operandType))
return emitError() << "operand and result must have same sizes and dtype"; return emitError() << "operand and result must have same sizes and dtype";
return success(); return success();
@ -2762,7 +2761,7 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands, MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType = operands[0].getType().cast<ValueTensorType>(); auto resultType = cast<ValueTensorType>(operands[0].getType());
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics()); inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
return success(); return success();
} }
@ -2778,8 +2777,8 @@ void CopyToNonValueTensorOp::getEffects(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult CopyToValueTensorOp::verify() { LogicalResult CopyToValueTensorOp::verify() {
auto resultType = getResult().getType().cast<BaseTensorType>(); auto resultType = cast<BaseTensorType>(getResult().getType());
auto operandType = getOperand().getType().cast<BaseTensorType>(); auto operandType = cast<BaseTensorType>(getOperand().getType());
if (!resultType.hasSameSizesAndDtype(operandType)) if (!resultType.hasSameSizesAndDtype(operandType))
return emitError() << "operand and result must have same sizes and dtype"; return emitError() << "operand and result must have same sizes and dtype";
return success(); return success();
@ -2789,7 +2788,7 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands, MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType = operands[0].getType().cast<NonValueTensorType>(); auto resultType = cast<NonValueTensorType>(operands[0].getType());
inferredReturnTypes.push_back(resultType.getWithValueSemantics()); inferredReturnTypes.push_back(resultType.getWithValueSemantics());
return success(); return success();
} }
@ -3004,7 +3003,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
auto operandType = getSelf().getType().dyn_cast<BaseTensorType>(); auto operandType = dyn_cast<BaseTensorType>(getSelf().getType());
if (!operandType) if (!operandType)
return nullptr; return nullptr;
if (operandType.hasDtype()) { if (operandType.hasDtype()) {
@ -3493,8 +3492,8 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>(); auto inType = dyn_cast<BaseTensorType>(getOperand(0).getType());
auto outType = getResult().getType().dyn_cast<BaseTensorType>(); auto outType = dyn_cast<BaseTensorType>(getResult().getType());
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!outType.hasDtype()) !outType.hasDtype())
return nullptr; return nullptr;
@ -3534,8 +3533,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd()); IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep()); IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim()); IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>(); auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
auto outType = getResult().getType().dyn_cast<ValueTensorType>(); auto outType = dyn_cast<ValueTensorType>(getResult().getType());
if (start && end && step && step.getValue().getSExtValue() == 1 && if (start && end && step && step.getValue().getSExtValue() == 1 &&
start.getValue().getSExtValue() == 0 && start.getValue().getSExtValue() == 0 &&
@ -3793,7 +3792,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>(); BaseTensorType tensorType = cast<BaseTensorType>(getA().getType());
if (tensorType.hasDtype()) { if (tensorType.hasDtype()) {
torch_upstream::ScalarType scalarType = torch_upstream::ScalarType scalarType =
Torch::getScalarTypeForType(tensorType.getDtype()); Torch::getScalarTypeForType(tensorType.getDtype());
@ -4568,7 +4567,7 @@ LogicalResult AtenNormScalarOp::verify() {
// Per PyTorch docs, only float and complex types are valid for norm // Per PyTorch docs, only float and complex types are valid for norm
// operation. // operation.
auto inTensor = getSelf().getType().cast<BaseTensorType>(); auto inTensor = cast<BaseTensorType>(getSelf().getType());
// If no dtype is specified, it will default to a float one. // If no dtype is specified, it will default to a float one.
if (!inTensor.hasDtype()) { if (!inTensor.hasDtype()) {
@ -4605,8 +4604,8 @@ LogicalResult AtenPermuteOp::verify() {
return success(); return success();
} }
auto outType = getResult().getType().cast<BaseTensorType>(); auto outType = cast<BaseTensorType>(getResult().getType());
auto inType = getSelf().getType().cast<BaseTensorType>(); auto inType = cast<BaseTensorType>(getSelf().getType());
if (!outType.hasSizes() || !inType.hasSizes()) { if (!outType.hasSizes() || !inType.hasSizes()) {
return success(); return success();
@ -4689,8 +4688,8 @@ LogicalResult AtenPermuteOp::verify() {
LogicalResult AtenLinalgCrossOp::verify() { LogicalResult AtenLinalgCrossOp::verify() {
auto selfType = getSelf().getType().cast<BaseTensorType>(); auto selfType = cast<BaseTensorType>(getSelf().getType());
auto otherType = getOther().getType().cast<BaseTensorType>(); auto otherType = cast<BaseTensorType>(getOther().getType());
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() || if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
!otherType.hasSizes()) { !otherType.hasSizes()) {
@ -4857,7 +4856,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
// Check that initial values satisfy type bounds. // Check that initial values satisfy type bounds.
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) { for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
auto symName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>(); auto symName = cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
auto initialValue = initialize.getOperand(i); auto initialValue = initialize.getOperand(i);
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue()); auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) { if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
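
A few context lines above (the `inferReturnTypes` hunks) still use the member `.dyn_cast_or_null<ElementsAttr>()`; its free-function counterpart is `dyn_cast_or_null<T>(x)` (spelled `dyn_cast_if_present` in newer LLVM), which additionally tolerates a null input. A hedged sketch, illustrative only:

#include "mlir/IR/BuiltinAttributes.h"

// Safe even when `attr` is a null Attribute: a null input yields null out.
static bool isSplatElements(mlir::Attribute attr) {
  auto elements = mlir::dyn_cast_or_null<mlir::ElementsAttr>(attr);
  return elements && elements.isSplat();
}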


@ -49,7 +49,7 @@ public:
// The incorporation of the torch.type_bound arg attr is context-dependent. // The incorporation of the torch.type_bound arg attr is context-dependent.
for (auto type : llvm::enumerate(func.getArgumentTypes())) { for (auto type : llvm::enumerate(func.getArgumentTypes())) {
if (type.value().isa<NonValueTensorType>()) { if (isa<NonValueTensorType>(type.value())) {
auto typeBoundAttr = auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent); func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type(); Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
@ -61,7 +61,7 @@ public:
? typeBoundAttr.getValue() ? typeBoundAttr.getValue()
: type.value()); : type.value());
continue; continue;
} else if (auto none = type.value().dyn_cast<Torch::NoneType>()) { } else if (auto none = dyn_cast<Torch::NoneType>(type.value())) {
continue; continue;
} }
// TODO: add tuple type. // TODO: add tuple type.
@ -111,7 +111,7 @@ public:
SmallVector<Value> newOperands; SmallVector<Value> newOperands;
for (auto operand : llvm::enumerate(adaptor.getOperands())) { for (auto operand : llvm::enumerate(adaptor.getOperands())) {
if (operand.value().getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(operand.value().getType()))
continue; continue;
auto it = typeBoundMap.find({call.getCallee(), operand.index()}); auto it = typeBoundMap.find({call.getCallee(), operand.index()});
if (it != typeBoundMap.end()) { if (it != typeBoundMap.end()) {
@ -167,9 +167,9 @@ public:
for (auto operand : adaptor.getOperands()) { for (auto operand : adaptor.getOperands()) {
if (!operand) if (!operand)
continue; continue;
if (operand.getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(operand.getType()))
continue; continue;
if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) { if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
Location loc = op.getLoc(); Location loc = op.getLoc();
for (auto en : llvm::enumerate(tuple.getContainedTypes())) { for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
auto i = rewriter.create<ConstantIntOp>( auto i = rewriter.create<ConstantIntOp>(
@ -207,7 +207,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
[](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs, [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
Location loc) -> Value { Location loc) -> Value {
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseTensorType>()); assert(isa<BaseTensorType>(inputs[0].getType()));
return copyTensorToType(builder, loc, type, inputs[0]); return copyTensorToType(builder, loc, type, inputs[0]);
}); });
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context); patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);


@ -29,7 +29,7 @@ using namespace mlir::torch::Torch;
// Helper function to check whether the `dtype` is None or Float type. // Helper function to check whether the `dtype` is None or Float type.
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
if (dtype.getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(dtype.getType()))
return true; return true;
int64_t dtypeInt; int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
@ -87,7 +87,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim); Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
Value dtype = rewriter.create<ConstantNoneOp>(loc); Value dtype = rewriter.create<ConstantNoneOp>(loc);
Type resultType = computeReductionType( Type resultType = computeReductionType(
rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim); rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim);
if (!resultType) if (!resultType)
return nullptr; return nullptr;
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList, return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
@ -100,7 +100,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
bool keepDim) { bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim); Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType = BaseTensorType valueType =
computeReductionType(rewriter, op, input.getType().cast<BaseTensorType>(), computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
dim, keepDim) dim, keepDim)
.cast<BaseTensorType>(); .cast<BaseTensorType>();
if (!valueType) if (!valueType)
@ -296,7 +296,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
int64_t contractingDimsLength, int64_t contractingDimsLength,
int64_t otherDimsLength, int64_t otherDimsLength,
int64_t reduceDimsLength, bool isLhs) { int64_t reduceDimsLength, bool isLhs) {
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
reduceDimsLength; reduceDimsLength;
SmallVector<Value> inputShapeTensor; SmallVector<Value> inputShapeTensor;
@ -415,7 +415,7 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
SmallVector<char> &contractingDims, SmallVector<char> &contractingDims,
SmallVector<char> &otherDims, SmallVector<char> &otherDims,
SmallVector<char> &reduceDims, bool isLhs) { SmallVector<char> &reduceDims, bool isLhs) {
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
llvm::SmallDenseMap<char, int64_t> dimTokenMap; llvm::SmallDenseMap<char, int64_t> dimTokenMap;
for (size_t idx = 0; idx < dimTokens.size(); ++idx) { for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
dimTokenMap[dimTokens[idx]] = idx; dimTokenMap[dimTokens[idx]] = idx;
@ -451,8 +451,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
Value &result, Value &result,
SmallVector<char> &resultTokens, SmallVector<char> &resultTokens,
SmallVector<char> &finalResultTokens) { SmallVector<char> &finalResultTokens) {
auto lhsType = lhs.getType().cast<BaseTensorType>(); auto lhsType = cast<BaseTensorType>(lhs.getType());
auto rhsType = rhs.getType().cast<BaseTensorType>(); auto rhsType = cast<BaseTensorType>(rhs.getType());
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
: rhsType.getOptionalDtype(); : rhsType.getOptionalDtype();
@ -562,7 +562,7 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
Value input, Value input,
SmallVector<char> &inputTokens, SmallVector<char> &inputTokens,
SmallVector<char> &outTokens) { SmallVector<char> &outTokens) {
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end()); llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end());
SmallVector<int64_t> sumDims; SmallVector<int64_t> sumDims;
@ -643,7 +643,7 @@ public:
op, "Expected a constant boolean value for keepDim"); op, "Expected a constant boolean value for keepDim");
Value input = op.getSelf(); Value input = op.getSelf();
auto inputTy = input.getType().dyn_cast<Torch::ValueTensorType>(); auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes()) { if (!inputTy || !inputTy.hasSizes()) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Expected input type having sizes"); "Expected input type having sizes");
@ -677,7 +677,7 @@ public:
MLIRContext *context = op.getContext(); MLIRContext *context = op.getContext();
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.getSelf(); Value input = op.getSelf();
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes() || !inputType.hasDtype()) { if (!inputType.hasSizes() || !inputType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "should have shape and dtype"); return rewriter.notifyMatchFailure(op, "should have shape and dtype");
} }
@ -764,7 +764,7 @@ public:
Value dim = op.getDim(); Value dim = op.getDim();
Value self = op.getSelf(); Value self = op.getSelf();
auto resultTy = op.getType().cast<BaseTensorType>(); auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.hasDtype()) { if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected result type to have sizes and dtype"); op, "expected result type to have sizes and dtype");
@ -785,8 +785,8 @@ public:
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one); rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
Value slice = rewriter.create<AtenSliceTensorOp>( Value slice = rewriter.create<AtenSliceTensorOp>(
loc, loc,
computeReductionType(rewriter, op, computeReductionType(rewriter, op, cast<BaseTensorType>(self.getType()),
self.getType().cast<BaseTensorType>(), dim, dim,
/*keepDim=*/true), /*keepDim=*/true),
op.getSelf(), dim, start, startPlusOne, /*step=*/one); op.getSelf(), dim, start, startPlusOne, /*step=*/one);
@ -988,7 +988,7 @@ public:
Value self = op.getSelf(); Value self = op.getSelf();
Value dim = op.getDim(); Value dim = op.getDim();
auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>(); auto outputTy = dyn_cast<Torch::ValueTensorType>(op.getType());
if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) { if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Expected output type having sizes and dtype"); op, "Expected output type having sizes and dtype");
@ -1069,7 +1069,7 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant"); "unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc); Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = op.getType().dyn_cast<BaseTensorType>(); auto outType = dyn_cast<BaseTensorType>(op.getType());
if (!outType) if (!outType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported"); op, "Only tensor types input are currently supported");
@ -1111,13 +1111,13 @@ public:
// compare unsqueezed input with boundaries // compare unsqueezed input with boundaries
auto eqType = ValueTensorType::get( auto eqType = ValueTensorType::get(
context, op.getType().cast<BaseTensorType>().getSizes(), context, cast<BaseTensorType>(op.getType()).getSizes(),
IntegerType::get(context, 1)); IntegerType::get(context, 1));
Value eqTensor = Value eqTensor =
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM); rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
Value dtype = op.getDtype(); Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::BoolType>()) { if (isa<Torch::BoolType>(dtype.getType())) {
rewriter.replaceOp(op, eqTensor); rewriter.replaceOp(op, eqTensor);
return success(); return success();
} else { } else {
@ -1210,7 +1210,7 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value input = op.getSelf(); Value input = op.getSelf();
// TODO: Handle non value tensor type operands. // TODO: Handle non value tensor type operands.
if (!input.getType().isa<ValueTensorType>()) { if (!isa<ValueTensorType>(input.getType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: only value tensor type operands are supported"); op, "unimplemented: only value tensor type operands are supported");
} }
@ -1248,7 +1248,7 @@ public:
} }
auto allTensorHasSizes = [](Value tensor) { auto allTensorHasSizes = [](Value tensor) {
auto type = tensor.getType().dyn_cast<BaseTensorType>(); auto type = dyn_cast<BaseTensorType>(tensor.getType());
if (!type || !type.hasSizes()) if (!type || !type.hasSizes())
return false; return false;
return true; return true;
@ -1267,7 +1267,7 @@ public:
if (equation.find("...") != std::string::npos) { if (equation.find("...") != std::string::npos) {
SmallVector<int64_t> inputRanks; SmallVector<int64_t> inputRanks;
for (Value tensor : inputTensors) { for (Value tensor : inputTensors) {
auto type = tensor.getType().cast<BaseTensorType>(); auto type = cast<BaseTensorType>(tensor.getType());
inputRanks.push_back(type.getSizes().size()); inputRanks.push_back(type.getSizes().size());
} }
@ -1332,10 +1332,10 @@ public:
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one = Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
BaseTensorType inputType = self.getType().cast<BaseTensorType>(); BaseTensorType inputType = cast<BaseTensorType>(self.getType());
Value output = op.getResult(); Value output = op.getResult();
BaseTensorType outputType = output.getType().cast<BaseTensorType>(); BaseTensorType outputType = cast<BaseTensorType>(output.getType());
ArrayRef<int64_t> inputShape = inputType.getSizes(); ArrayRef<int64_t> inputShape = inputType.getSizes();
int64_t diagonalSize = std::min(inputShape[0], inputShape[1]); int64_t diagonalSize = std::min(inputShape[0], inputShape[1]);
@ -1399,7 +1399,7 @@ public:
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value self = op.getSelf(); Value self = op.getSelf();
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>(); BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
if (!resultTensorType.hasDtype()) { if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype"); op, "expected result type to have a dtype");
@ -1410,7 +1410,7 @@ public:
"Only support floating-point type"); "Only support floating-point type");
// If `dtype` arg is non-none then convert the input to `dtype`. // If `dtype` arg is non-none then convert the input to `dtype`.
if (!op.getDtype().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getDtype().getType())) {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value none = rewriter.create<ConstantNoneOp>(loc); Value none = rewriter.create<ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false); Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
@ -1440,15 +1440,15 @@ public:
LogicalResult matchAndRewrite(Aten_SoftmaxOp op, LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value self = op.getSelf(); Value self = op.getSelf();
BaseTensorType tensorType = self.getType().cast<BaseTensorType>(); BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>()) if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
return rewriter.notifyMatchFailure(op, "Only support floating type"); return rewriter.notifyMatchFailure(op, "Only support floating type");
bool halfToFloat; bool halfToFloat;
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Expected a boolean value for half_to_float"); op, "Expected a boolean value for half_to_float");
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>(); BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
if (!resultTensorType.hasDtype()) { if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype"); op, "expected result type to have a dtype");
@@ -1500,8 +1500,8 @@ public:
     Value output = op.getOutput();
     Value dim = op.getDim();
-    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
-    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+    BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
+    if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
       return rewriter.notifyMatchFailure(op, "Only support floating type");
     Value newGrad =
@@ -1536,8 +1536,8 @@ public:
     // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
     Value output = op.getOutput();
-    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
-    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+    BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
+    if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
       return rewriter.notifyMatchFailure(op, "Only support floating type");
     Value tanhSquare =
@@ -1567,8 +1567,8 @@ public:
     Value output = op.getOutput();
     Value dim = op.getDim();
-    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
-    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+    BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
+    if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
       return rewriter.notifyMatchFailure(op, "Only support floating type");
     Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
@@ -1650,8 +1650,8 @@ public:
     Value keepDim = op.getKeepdim();
     Value result = op.getResult();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
-    BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
+    BaseTensorType indicesTensorType = cast<BaseTensorType>(result.getType());
     std::optional<unsigned> maybeInputRank = getTensorRank(input);
     if (!maybeInputRank) {
       return rewriter.notifyMatchFailure(
@@ -1670,7 +1670,7 @@ public:
     // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
     // first the input tensor is flattened to 1d tensor and then the reduction
     // happens on the 0th dimension.
-    if (dim.getType().isa<Torch::NoneType>()) {
+    if (isa<Torch::NoneType>(dim.getType())) {
       BaseTensorType flattenType =
           inputType
               .getWithSizesAndDtype({kUnknownSize},
@@ -1720,7 +1720,7 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasSizes()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: input must have known sizes");
@@ -1728,7 +1728,7 @@ public:
     ArrayRef<int64_t> inputShape = inputType.getSizes();
     Value boundaries = op.getBoundaries();
-    auto boundariesType = boundaries.getType().cast<BaseTensorType>();
+    auto boundariesType = cast<BaseTensorType>(boundaries.getType());
     if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
       return rewriter.notifyMatchFailure(op,
                                          "unimplemented: boundaries must have "
@@ -1827,7 +1827,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   Value dim = op.getDim();
   Value self = op.getSelf();
-  BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
+  BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
   Value xMax =
       createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
   if (!xMax)
@@ -1856,12 +1856,12 @@ public:
   LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
                                 PatternRewriter &rewriter) const override {
     Value self = op.getSelf();
-    if (!op.getDtype().getType().isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(op.getDtype().getType()))
       return rewriter.notifyMatchFailure(
           op, "Unimplemented non-None dtype for log_softmax");
-    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
-    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+    BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
+    if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
       return rewriter.notifyMatchFailure(op, "Only support floating type");
     Value logSoftmax = getLogSoftmaxResult(op, rewriter);
@@ -1974,7 +1974,7 @@ public:
     Type opType = op.getType();
     Value dim = op.getDim();
-    auto resType = self.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(self.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2088,7 +2088,7 @@ public:
     Location loc = op.getLoc();
     Value inValue = op.getSelf();
-    auto inType = inValue.getType().cast<BaseTensorType>();
+    auto inType = cast<BaseTensorType>(inValue.getType());
     auto maybeSizes = inType.getOptionalSizes();
     if (!maybeSizes) {
       return rewriter.notifyMatchFailure(
@@ -2234,7 +2234,7 @@ public:
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
-  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+  BaseTensorType inputType = cast<BaseTensorType>(input.getType());
   Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
   Value cst6 =
@@ -2252,7 +2252,7 @@ public:
   LogicalResult matchAndRewrite(AtenRelu6Op op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2304,7 +2304,7 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     Value negativeSlope = op.getNegativeSlope();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2341,7 +2341,7 @@ public:
     Value gradOutput = op.getGradOutput();
     Value input = op.getSelf();
     Value negativeSlope = op.getNegativeSlope();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2382,7 +2382,7 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     Value weight = op.getWeight();
-    auto resType = op.getType().cast<ValueTensorType>();
+    auto resType = cast<ValueTensorType>(op.getType());
     auto boolTensorType = rewriter.getType<ValueTensorType>(
         resType.getOptionalSizes(), rewriter.getI1Type());
     Value zero =
@@ -2408,14 +2408,14 @@ public:
   LogicalResult matchAndRewrite(AtenLerpScalarOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
     Value cstOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     auto start = op.getSelf();
-    auto inputType = start.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(start.getType());
     auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
                                                   start, cstOne);
@@ -2442,7 +2442,7 @@ public:
     Value alpha = op.getAlpha();
     Value scale = op.getScale();
     Value inputScale = op.getInputScale();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2486,7 +2486,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -2578,7 +2578,7 @@ public:
     }
     // Ensure all tensors have known sizes
     for (Value tensor : tensors) {
-      BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+      BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
       if (!tensorType.hasSizes()) {
         return rewriter.notifyMatchFailure(
             op, "unimplemented: one tensor does not have known sizes");
@@ -2596,7 +2596,8 @@ public:
     }
     Type listElemType =
-        op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
+        cast<BaseTensorType>(op.getType())
+            .getWithSizesAndDtype(
                 /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
     Type listType = Torch::ListType::get(listElemType);
     Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
@@ -2635,7 +2636,7 @@ public:
     Value constOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     auto self = op.getSelf();
-    auto selfTy = self.getType().cast<BaseTensorType>();
+    auto selfTy = cast<BaseTensorType>(self.getType());
     // roll(input, shift, dim) = cat({
     //   slice(input, dim, -shift, none),
     //   slice(input, dim, 0, -shift)}, dim)
@@ -2817,7 +2818,7 @@ public:
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(
          op, "Unimplemented: no implementation for rankless tensor");
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "Unimplemented: no implementation for rankless tensor");
@@ -2968,7 +2969,7 @@ public:
     Location loc = op.getLoc();
     Value self = op.getSelf();
     MLIRContext *context = op.getContext();
-    BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
+    BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
     if (!outputTensorType.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "unimplemented: output must have known sizes");
@@ -2977,7 +2978,7 @@ public:
     if (!maybeRank)
       return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
     unsigned inputRank = *maybeRank;
-    auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
+    auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
     if (!inputTensorType || !inputTensorType.hasSizes()) {
       return rewriter.notifyMatchFailure(op,
                                          "Expected input type having sizes");
@@ -3077,7 +3078,7 @@ public:
   LogicalResult matchAndRewrite(AtenWhereScalarOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3100,7 +3101,7 @@ public:
   LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3122,7 +3123,7 @@ public:
   LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3186,7 +3187,7 @@ public:
   LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resType = op.getType().cast<BaseTensorType>();
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3227,7 +3228,7 @@ static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
                                                       int64_t dimB,
                                                       Value &transposed) {
   Type transposedType;
-  if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
+  if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
                                dimA, dimB, transposedType)))
     return failure();
   Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
@@ -3578,7 +3579,7 @@ public:
                                 op.getGroups(), op.getDilation());
     Type transposedType;
-    if (failed(getTransposedType(input.getType().cast<BaseTensorType>(), 0, 1,
+    if (failed(getTransposedType(cast<BaseTensorType>(input.getType()), 0, 1,
                                  transposedType)))
       return failure();
     Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
@@ -3605,7 +3606,7 @@ public:
         ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2],
                    gradOutputSize[3]});
-    BaseTensorType gradOutputTy = gradOutput.getType().cast<BaseTensorType>();
+    BaseTensorType gradOutputTy = cast<BaseTensorType>(gradOutput.getType());
     if (!gradOutputTy.hasSizes())
       return failure();
     SmallVector<int64_t> gradOutputSizesInt(gradOutputTy.getSizes());
@@ -3625,7 +3626,7 @@ public:
         loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
     BaseTensorType inputTransposedTy =
-        inputTransposed.getType().cast<BaseTensorType>();
+        cast<BaseTensorType>(inputTransposed.getType());
     if (!inputTransposedTy.hasSizes())
       return failure();
     SmallVector<int64_t> inputTransposedSizesInt(
@@ -3660,7 +3661,7 @@ public:
         /*dilation=*/op.getStride(), op.getTransposed(),
         op.getOutputPadding(), numGroup);
-    BaseTensorType weightTy = weight.getType().cast<BaseTensorType>();
+    BaseTensorType weightTy = cast<BaseTensorType>(weight.getType());
     if (!weightTy.hasSizes())
       return failure();
     SmallVector<int64_t> weightSizes(weightTy.getSizes());
@@ -3707,7 +3708,7 @@ public:
       gradWeight = rewriter.create<Torch::AtenViewOp>(
           loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
-      gradWeightTy = gradWeight.getType().cast<BaseTensorType>();
+      gradWeightTy = cast<BaseTensorType>(gradWeight.getType());
       SmallVector<int64_t, 5> gradWeightDimsOrder =
           computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size());
       SmallVector<int64_t, 5> gradWeightMoveDimShape;
@@ -3733,7 +3734,7 @@ public:
                                              /*keepdim=*/cstFalse,
                                              /*dtype=*/cstNone);
     } else {
-      if (failed(getTransposedType(gradOutput.getType().cast<BaseTensorType>(),
+      if (failed(getTransposedType(cast<BaseTensorType>(gradOutput.getType()),
                                    0, 1, transposedType)))
         return failure();
       Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
@@ -3792,7 +3793,7 @@ public:
     }
     // TODO: Handle integer type operands.
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: non-floating point dtype");
@@ -3821,7 +3822,7 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     Value output = op.getResult();
-    BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
+    BaseTensorType outputTensorType = cast<BaseTensorType>(output.getType());
     Value sum =
         rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
     Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
@@ -3854,7 +3855,7 @@ public:
     Type outputType = op.getType();
     MLIRContext *context = op.getContext();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
         !isNoneOrFloatDtype(context, dtype)) {
       return rewriter.notifyMatchFailure(
@@ -3944,7 +3945,7 @@ public:
       rewriter.replaceOp(op, input);
       return success();
     }
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(
           op, "only support floating type input for training mode");
@@ -3992,7 +3993,7 @@ public:
       rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
       return success();
     }
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "only support floating type input for training mode");
@@ -4029,7 +4030,7 @@ public:
       return rewriter.notifyMatchFailure(op, "expected input to have a rank");
     }
     unsigned inputRank = *maybeInputRank;
-    BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
+    BaseTensorType rank0FloatTensorTy = cast<BaseTensorType>(op.getType());
     if (!rank0FloatTensorTy.hasSizes() ||
         rank0FloatTensorTy.getSizes().size() != 0) {
       return rewriter.notifyMatchFailure(
@@ -4060,7 +4061,7 @@ public:
   LogicalResult matchAndRewrite(AtenStdOp op,
                                 PatternRewriter &rewriter) const override {
     Value self = op.getSelf();
-    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
+    BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
     if (!inputTensorTy.hasDtype() ||
         !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(op,
@@ -4084,7 +4085,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     Value inputTimesBeta =
         rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
@@ -4116,7 +4117,7 @@ public:
   LogicalResult matchAndRewrite(AtenStdDimOp op,
                                 PatternRewriter &rewriter) const override {
     Value self = op.getSelf();
-    BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
+    BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
     if (!inputTensorType.hasDtype() ||
         !inputTensorType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
@@ -4141,7 +4142,7 @@ public:
   LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
                                 PatternRewriter &rewriter) const override {
     Value self = op.getSelf();
-    BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
+    BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
     if (!inputTensorType.hasDtype() ||
         !inputTensorType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
@@ -4167,8 +4168,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
-    auto resType = op.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -4208,8 +4209,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
-    auto resType = op.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
+    auto resType = cast<BaseTensorType>(op.getType());
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -4235,7 +4236,7 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     Type resultType = op.getType();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(op,
                                          "only support floating-point type");
@@ -4268,8 +4269,8 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
                                               Operation *op, Location loc,
                                               Value input, Value prob,
                                               Value &output) {
-  auto inputType = input.getType().cast<BaseTensorType>();
-  auto probType = prob.getType().cast<BaseTensorType>();
+  auto inputType = cast<BaseTensorType>(input.getType());
+  auto probType = cast<BaseTensorType>(prob.getType());
   // Both the `input` and `prob` must be ranked tensors.
   if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
       !probType.hasDtype()) {
@@ -4338,12 +4339,12 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     Value p = op.getP();
-    if (!op.getGenerator().getType().template isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(op.getGenerator().getType()))
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     SmallVector<int64_t> empty;
     Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
                                                      rewriter.getF64Type());
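
Note: in templated patterns the free functions also drop the `template` disambiguator that the member form required on type-dependent expressions, as in the `...getType().template isa<Torch::NoneType>()` line above. A sketch, where `OpTy` stands in for any generated op class with a `getGenerator()` accessor:

    template <typename OpTy>
    static bool generatorIsNone(OpTy op) {
      // Member form needed: op.getGenerator().getType().template isa<...>()
      return isa<Torch::NoneType>(op.getGenerator().getType());
    }
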
@@ -4485,7 +4486,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto input = op.getInput().getType().cast<BaseTensorType>();
+    auto input = cast<BaseTensorType>(op.getInput().getType());
     if (!input.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
@@ -4518,7 +4519,7 @@ class DecomposeAtenInstanceNormOp
     Location loc = op.getLoc();
     auto context = op.getContext();
-    auto inputTy = op.getInput().getType().cast<BaseTensorType>();
+    auto inputTy = cast<BaseTensorType>(op.getInput().getType());
     int64_t inputRank = inputTy.getSizes().size();
     SmallVector<int64_t> reducedShape(inputTy.getSizes());
     SmallVector<int64_t> reduceDimInts;
@@ -4583,7 +4584,7 @@ class DecomposeAtenInstanceNormOp
         loc, op.getResult().getType(), inputNormalized);
     Value weight = op.getWeight();
-    auto weightTy = weight.getType().cast<BaseTensorType>();
+    auto weightTy = cast<BaseTensorType>(weight.getType());
     dtype = weightTy.getOptionalDtype();
     SmallVector<int64_t> weightShape(weightTy.getSizes());
@@ -4610,7 +4611,7 @@ class DecomposeAtenInstanceNormOp
         rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
     Value bias = op.getBias();
-    auto biasTy = bias.getType().cast<BaseTensorType>();
+    auto biasTy = cast<BaseTensorType>(bias.getType());
     dtype = biasTy.getOptionalDtype();
     SmallVector<int64_t> biasShape(biasTy.getSizes());
@@ -4654,7 +4655,7 @@ class DecomposeAtenNativeLayerNormOp
     Location loc = op.getLoc();
     auto context = op.getContext();
-    auto inputTy = op.getInput().getType().cast<BaseTensorType>();
+    auto inputTy = cast<BaseTensorType>(op.getInput().getType());
     if (!inputTy.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
@@ -4889,10 +4890,10 @@ class DecomposeAtenNativeGroupNormOp
     Value eps = op.getEps();
     // Check the rank of the input/outputs tensor.
-    auto inputType = input.getType().cast<BaseTensorType>();
-    auto outputType = op.getResult0().getType().cast<BaseTensorType>();
-    auto meanType = op.getResult1().getType().cast<BaseTensorType>();
-    auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
+    auto outputType = cast<BaseTensorType>(op.getResult0().getType());
+    auto meanType = cast<BaseTensorType>(op.getResult1().getType());
+    auto rsqrtVarType = cast<BaseTensorType>(op.getResult2().getType());
     if (!inputType.hasSizes() || !outputType.hasSizes() ||
         !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
       return rewriter.notifyMatchFailure(
@@ -5059,8 +5060,8 @@ class DecomposeAtenNativeBatchNormOp
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
     runningStatsShapeInt[1] =
-        runningMean.getType().cast<BaseTensorType>().getSizes()[0];
-    Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
+        cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
+    Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@@ -5175,8 +5176,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
                                 PatternRewriter &rewriter) const override {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
-      BaseTensorType tensorType =
-          op.getSelf().getType().template cast<BaseTensorType>();
+      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
       if (!tensorType.hasDtype()) {
         return rewriter.notifyMatchFailure(
             op, "expected input tensor to have a dtype");
@@ -5200,7 +5200,7 @@ public:
   LogicalResult matchAndRewrite(AtenFullOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    BaseTensorType outTy = cast<BaseTensorType>(op.getType());
     if (!outTy.hasDtype()) {
       return rewriter.notifyMatchFailure(
           op, "expected result type to have a dtype");
@@ -5231,12 +5231,12 @@ public:
     Value weight = op.getWeight();
     Value bias = op.getBias();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
       return rewriter.notifyMatchFailure(
           op, "expected input to be rank 2 or greater");
-    BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
+    BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
     // `weight` must be a rank 2 matrix.
     if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
       return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
@@ -5255,7 +5255,7 @@ public:
       return success();
     }
-    BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
+    BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
     if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
       return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
@@ -5280,7 +5280,7 @@ public:
     Value input = op.getSelf();
     Type type = op.getType();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasDtype())
       return rewriter.notifyMatchFailure(op, "Dtype not present");
@@ -5306,7 +5306,7 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                 PatternRewriter &rewriter) const override {
-    BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    BaseTensorType outTy = cast<BaseTensorType>(op.getType());
     if (!outTy.hasDtype()) {
       return rewriter.notifyMatchFailure(
           op, "expected result type to have a dtype");
@@ -5335,7 +5335,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
-      BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
       if (!tensorType.hasDtype()) {
         return rewriter.notifyMatchFailure(
             op, "expected input tensor to have a dtype");
@@ -5393,7 +5393,7 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    auto resultType = op.getType().cast<BaseTensorType>();
+    auto resultType = cast<BaseTensorType>(op.getType());
     if (!resultType.hasDtype()) {
       return rewriter.notifyMatchFailure(
           op, "expected result type to have a dtype");
@@ -5419,12 +5419,12 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    auto resultType = op.getType().cast<BaseTensorType>();
+    auto resultType = cast<BaseTensorType>(op.getType());
     if (!resultType.hasDtype()) {
       return rewriter.notifyMatchFailure(
           op, "expected result type to have a dtype");
     }
-    auto srcTy = op.getSrc().getType().cast<BaseTensorType>();
+    auto srcTy = cast<BaseTensorType>(op.getSrc().getType());
     if (!srcTy.hasSizes() || !srcTy.hasDtype()) {
       return rewriter.notifyMatchFailure(
           op, "expected src type to have a known rank and dtype");
@@ -5448,7 +5448,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
-      BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
       if (!tensorType.hasDtype()) {
         return rewriter.notifyMatchFailure(
             op, "expected input tensor to have a dtype");
@@ -5588,7 +5588,7 @@ public:
     Value constNone = rewriter.create<ConstantNoneOp>(loc);
     Value dtype = op.getDtype();
-    if (dtype.getType().template isa<Torch::NoneType>()) {
+    if (isa<Torch::NoneType>(dtype.getType())) {
       dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
     }
     rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
@@ -5665,7 +5665,7 @@ class DecomposeAtenAdaptiveAvgPool1dOp
     SmallVector<Value, 1> kernelSize;
     if (outputSizeInt == 1) {
-      BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
+      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
       ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
       kernelSize.push_back(
           inputShape[rank - 1] == kUnknownSize
@@ -5839,7 +5839,7 @@ class DecomposeAtenCosineSimilarityOp
     SmallVector<Value> indexBroadcastShapeValue;
     computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
                           indexBroadcastShapeValue);
-    Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
+    Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
     Type broadcastType = ValueTensorType::get(
         op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
     Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
@@ -5925,9 +5925,9 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
     Value alphaTimesBmm =
         rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
     Value input = op.getSelf();
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     BaseTensorType resultType =
-        op->getResult(0).getType().cast<BaseTensorType>();
+        cast<BaseTensorType>(op->getResult(0).getType());
     if (inputType.hasDtype() && resultType.hasDtype() &&
         inputType.getDtype() != resultType.getDtype()) {
       input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
@@ -6011,7 +6011,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   Value self = op.getSelf();
   Value dimList = op.getDim();
   Value keepDim = op.getKeepdim();
-  BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
+  BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
   Type outputType = op.getType();
   BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
   if (!outputTensorType.hasDtype()) {
@@ -6030,7 +6030,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   // computation of the result.
   if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
     self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
-    inputTensorTy = self.getType().cast<BaseTensorType>();
+    inputTensorTy = cast<BaseTensorType>(self.getType());
   }
   std::optional<unsigned> maybeInputRank = getTensorRank(self);
@@ -6040,7 +6040,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   unsigned inputRank = *maybeInputRank;
   SmallVector<Value> dimListElements;
   bool isNoneOrEmpty = true;
-  if (!dimList.getType().template isa<Torch::NoneType>()) {
+  if (!isa<Torch::NoneType>(dimList.getType())) {
    if (!getListConstructElements(dimList, dimListElements))
      return rewriter.notifyMatchFailure(
          op, "expect dimList to be constructed from list construct");
@@ -6287,8 +6287,8 @@ public:
           op, "Expected a constant integer value for reduction");
     Location loc = op.getLoc();
-    BaseTensorType resultType = op.getType().cast<BaseTensorType>();
-    BaseTensorType inputType = op.getSelf().getType().cast<BaseTensorType>();
+    BaseTensorType resultType = cast<BaseTensorType>(op.getType());
+    BaseTensorType inputType = cast<BaseTensorType>(op.getSelf().getType());
     if (!inputType.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "Expected the input tensor to have sizes");
@@ -6506,7 +6506,7 @@ public:
   LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resultType = op.getType().cast<BaseTensorType>();
+    auto resultType = cast<BaseTensorType>(op.getType());
     if (!resultType.hasDtype()) {
       return rewriter.notifyMatchFailure(
@@ -6617,7 +6617,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resultType = op.getType().cast<BaseTensorType>();
+    auto resultType = cast<BaseTensorType>(op.getType());
     if (!resultType.hasDtype()) {
       return rewriter.notifyMatchFailure(
@@ -6943,7 +6943,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
     auto context = op.getContext();
     Value input = op.getSelf();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
@@ -6974,7 +6974,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
     // compare
     auto eqType = ValueTensorType::get(
-        context, op.getType().cast<BaseTensorType>().getSizes(),
+        context, cast<BaseTensorType>(op.getType()).getSizes(),
         IntegerType::get(context, 1));
     Value eqTensor = rewriter.create<AtenEqTensorOp>(
         loc, eqType, unsqueezeTensor, arangeTensor);
@@ -7019,7 +7019,7 @@ public:
   LogicalResult matchAndRewrite(AtenScalarTensorOp op,
                                 PatternRewriter &rewriter) const override {
-    auto resultTy = op.getResult().getType().cast<BaseTensorType>();
+    auto resultTy = cast<BaseTensorType>(op.getResult().getType());
     auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
     Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
         op.getLoc(),
@@ -7060,7 +7060,7 @@ public:
     Value self = op.getSelf();
     Value dim = op.getDim();
-    auto selfType = self.getType().cast<BaseTensorType>();
+    auto selfType = cast<BaseTensorType>(self.getType());
     auto sortIndicesType = selfType.getWithSizesAndDtype(
         selfType.getOptionalSizes(),
         IntegerType::get(context, 64, IntegerType::Signed));
@@ -7111,8 +7111,8 @@ public:
     Value sizeList = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(IntType::get(context)), sizes);
-    auto selfType = self.getType().cast<BaseTensorType>();
-    auto indexType = index.getType().cast<BaseTensorType>();
+    auto selfType = cast<BaseTensorType>(self.getType());
+    auto indexType = cast<BaseTensorType>(index.getType());
     BaseTensorType srcType =
         selfType
             .getWithSizesAndDtype(indexType.getOptionalSizes(),
@@ -7135,7 +7135,7 @@ public:
   LogicalResult matchAndRewrite(AtenSgnOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto outType = op.getType().cast<BaseTensorType>();
+    auto outType = cast<BaseTensorType>(op.getType());
     if (!outType.hasDtype()) {
       return rewriter.notifyMatchFailure(op,
                                          "expected result type to have dtype");
@@ -7273,14 +7273,14 @@ public:
                                        "failed to get elements of `indices`");
     auto input = op.getSelf();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasSizes()) {
       return rewriter.notifyMatchFailure(
           op, "only input with shape information is supported");
     }
     auto inputSizes = inputType.getSizes();
     int64_t inputRank = inputSizes.size();
-    auto outputType = op.getType().cast<BaseTensorType>();
+    auto outputType = cast<BaseTensorType>(op.getType());
     if (!outputType.hasSizes()) {
       return rewriter.notifyMatchFailure(
           op, "only output with shape information is supported");
@@ -7438,7 +7438,7 @@ public:
           op, "failed to get elements of `dims` param");
     }
     auto dimsSize = dimsElements.size();
-    auto inputType = input.getType().cast<BaseTensorType>();
+    auto inputType = cast<BaseTensorType>(input.getType());
     if (!inputType.hasSizes()) {
       return rewriter.notifyMatchFailure(
           op, "only support input tensor with shape information");


@@ -89,7 +89,7 @@ public:
                            .cast<ValueTensorType>()
                            .getOptionalDtype();
     auto torchQType =
-        quant.getType().cast<ValueTensorType>().getOptionalDtype();
+        cast<ValueTensorType>(quant.getType()).getOptionalDtype();
     auto transQTy =
         rewriter.getType<ValueTensorType>(trans.getResult()
                                               .getType()
@@ -152,7 +152,7 @@ public:
       return failure();
     Value bias = operands[2];
-    auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
+    auto biasTy = dyn_cast<ValueTensorType>(bias.getType());
     if (biasTy) {
       auto biasETy = biasTy.getOptionalDtype();
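
Note: dyn_cast migrates the same way and keeps its null-on-mismatch result, so guards like the `if (biasTy)` above are unchanged. A sketch of the idiomatic free-function form:

    // Returns a null ValueTensorType when bias.getType() is something else.
    if (auto biasTy = dyn_cast<ValueTensorType>(bias.getType())) {
      // biasTy is a valid ValueTensorType here.
    }
    // For an input that may itself be null, the free-function counterpart is
    // dyn_cast_or_null (named dyn_cast_if_present in newer LLVM).
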


@@ -134,7 +134,7 @@ private:
       slotName = setAttrOp.getName();
     }

-    auto moduleType = module.getType().cast<NnModuleType>();
+    auto moduleType = cast<NnModuleType>(module.getType());
     auto slots = moduleClassNameToSlots.find(moduleType.getClassName());
     // TODO: Improve verifier so that this can never happen
     if (slots == moduleClassNameToSlots.end())
@@ -163,13 +163,13 @@ private:
    }

    auto classType = symbolTable.lookup<ClassTypeOp>(
-        nnModule.getType().cast<NnModuleType>().getClassName());
+        cast<NnModuleType>(nnModule.getType()).getClassName());
    for (auto t :
         llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
      auto slot = std::get<0>(t);
      auto attr = std::get<1>(t);
      nameStack.push_back(attr.getName().str());
-      if (attr.getType().isa<NnModuleType>()) {
+      if (isa<NnModuleType>(attr.getType())) {
        if (failed(recursivelyTraverse(
                slot.getValue().getDefiningOp<NnModuleOp>())))
          return failure();
@@ -333,7 +333,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
  for (auto &argInstance : argInstances)
    mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
  auto walkResult = func.walk([&](PrimGetAttrOp op) {
-    if (!op.getType().isa<NnModuleType>())
+    if (!isa<NnModuleType>(op.getType()))
      return WalkResult::advance();
    auto instance = mapping.lookupOrNull(op.getReceiver());
    assert(instance && "verifyFuncConformsToSubset should ensure this");
@@ -355,7 +355,7 @@ createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
  Monomorphization monomorphization;
  monomorphization.func = func;
  for (auto operand : llvm::enumerate(op->getOperands())) {
-    if (!operand.value().getType().isa<NnModuleType>())
+    if (!isa<NnModuleType>(operand.value().getType()))
      continue;
    Value instance = mapping.lookupOrNull(operand.value());
    assert(instance && "verifyFuncConformsToSubset should ensure this");
@@ -377,7 +377,7 @@ public:
    monomorphization.func = func;
    bool canTriviallyMonomorphize = true;
    for (auto arg : llvm::enumerate(func.getArguments())) {
-      auto type = arg.value().getType().dyn_cast<NnModuleType>();
+      auto type = dyn_cast<NnModuleType>(arg.value().getType());
      if (!type)
        continue;
      auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
@@ -436,7 +436,7 @@ private:
  // !torch.nn.Module<"..."> types.
  static LogicalResult verifyNnModuleValueUses(Value value) {
    // Trivially succeed for non-module types.
-    if (!value.getType().isa<NnModuleType>())
+    if (!isa<NnModuleType>(value.getType()))
      return success();
    for (Operation *op : value.getUsers()) {
      if (isa<func::CallOp, PrimGetAttrOp>(op))
@@ -516,7 +516,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
    return WalkResult::advance();
  };
  auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
-    if (!op.getType().isa<NnModuleType>()) {
+    if (!isa<NnModuleType>(op.getType())) {
      auto instance =
          mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
      SlotOp affectedSlot;
@@ -540,7 +540,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
  Monomorphization monomorphization = std::move(*maybeMonomorphization);
  auto newArguments = llvm::to_vector<6>(
      llvm::make_filter_range(op->getOperands(), [](Value v) {
-        return !v.getType().isa<NnModuleType>();
+        return !isa<NnModuleType>(v.getType());
      }));
  assert(newFuncs.find(monomorphization) != newFuncs.end());
  auto newOp = OpBuilder(op).create<func::CallOp>(
@@ -564,7 +564,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
  }
  llvm::BitVector argsToErase(func.getNumArguments());
  for (auto type : llvm::enumerate(func.getArgumentTypes())) {
-    if (type.value().isa<NnModuleType>()) {
+    if (isa<NnModuleType>(type.value())) {
      argsToErase.set(type.index());
    }
  }
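
All of the hunks in this file follow the same mechanical rewrite: the deprecated member-function casts on mlir::Type move to the equivalent free functions. A minimal sketch of the three recurring shapes, assuming a Value `v`; the function and the `SomeType` placeholder are illustrative, not part of the patch:

    #include "mlir/IR/Value.h"
    #include "mlir/Support/LLVM.h" // re-exports llvm::cast/dyn_cast/isa

    template <typename SomeType>
    void illustrateCastMigration(mlir::Value v) {
      // Deprecated member-function style:
      //   v.getType().cast<SomeType>();
      //   v.getType().dyn_cast<SomeType>();
      //   v.getType().isa<SomeType>();
      // Free-function replacements:
      auto a = mlir::cast<SomeType>(v.getType());     // asserts on mismatch
      auto b = mlir::dyn_cast<SomeType>(v.getType()); // null value on mismatch
      bool c = mlir::isa<SomeType>(v.getType());      // predicate only
      (void)a; (void)b; (void)c;
    }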

View File

@@ -248,8 +248,8 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
        }))
      continue;
    if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
-      auto symName = initialize.getSlotSymNames()[use.getOperandNumber()]
-                         .cast<FlatSymbolRefAttr>();
+      auto symName = cast<FlatSymbolRefAttr>(
+          initialize.getSlotSymNames()[use.getOperandNumber()]);
      auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
          value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
      if (state->isSafe)
@@ -333,10 +333,10 @@ class InlineGlobalSlotsPass
    DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
    for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
      auto slotSymName =
-          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
+          cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
      Value operand = initialize.getOperand(i);
      auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
-          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>());
+          cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
      auto *state =
          solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
      // We roll the analysis of whether a slot is set or public into the
@@ -408,7 +408,7 @@ class InlineGlobalSlotsPass
    SmallVector<Value> newInitialValues;
    for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
      auto slotSymName =
-          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
+          cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
      if (!safeToInline.count(slotSymName)) {
        newSlotSymNames.push_back(slotSymName);
        newInitialValues.push_back(initialize.getOperand(i));

View File

@@ -118,7 +118,7 @@ static LogicalResult checkType(Operation *op, Type type,
  if (auto optionalType = dyn_cast<OptionalType>(type)) {
    // TODO: Be stricter about tensor types.
    // See comment below for ListType.
-    if (optionalType.getContainedType().isa<ValueTensorType>())
+    if (isa<ValueTensorType>(optionalType.getContainedType()))
      return success();
    return checkType(op, optionalType.getContainedType(),
                     actuallyEmitDiagnostics);
@@ -134,7 +134,7 @@ static LogicalResult checkType(Operation *op, Type type,
    // the contained type information. Somehow this slips through and works.
    // We should be stricter about this and properly infer the contained type
    // and shape.
-    if (listType.getContainedType().isa<ValueTensorType>())
+    if (isa<ValueTensorType>(listType.getContainedType()))
      return success();
    return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
  }
@@ -535,7 +535,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  }
  target.addDynamicallyLegalOp<OperatorOp>(
      [backendLegalOpsSet](OperatorOp opOp) {
-        auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+        auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
        return backendLegalOpsSet.contains(opName);
      });
}

View File

@@ -62,7 +62,7 @@ public:
        op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
        op.getOperand(3), op.getOperand(4));

-    auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
+    auto clampTy = cast<Torch::ValueTensorType>(clamp.getType());
    if (!clampTy.hasDtype())
      return rewriter.notifyMatchFailure(op,
                                         "dequantization has unknown dtype");

View File

@@ -23,7 +23,7 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

static Value assertNonValueTensor(Value tensor) {
-  assert(tensor.getType().isa<NonValueTensorType>() &&
+  assert(isa<NonValueTensorType>(tensor.getType()) &&
         "tensor is expected to be a non-value tensor");
  return tensor;
}
@@ -102,7 +102,7 @@ public:
        // to use value semantics (which happens for example with ops
        // that take two aliases as input), then it is possible that the
        // op no longer generates an alias.
-        if (userResult.getType().isa<NonValueTensorType>())
+        if (isa<NonValueTensorType>(userResult.getType()))
          availableAliases.insert(userResult);
        result.viewLikeOps.push_back(user);
      } else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
@@ -177,7 +177,7 @@ public:
    for (Operation *viewLikeOp : ops.viewLikeOps) {
      rewriter.modifyOpInPlace(viewLikeOp, [&] {
        Value result = viewLikeOp->getResult(0);
-        auto resultType = result.getType().dyn_cast<NonValueTensorType>();
+        auto resultType = dyn_cast<NonValueTensorType>(result.getType());
        if (resultType)
          result.setType(resultType.getWithValueSemantics());
      });
@@ -230,7 +230,7 @@ public:
    if (isViewLikeOp(op)) {
      // We currently only support view-like ops with one tensor output.
      if (op->getNumResults() != 1 ||
-          !op->getResult(0).getType().isa<BaseTensorType>()) {
+          !isa<BaseTensorType>(op->getResult(0).getType())) {
        return rewriter.notifyMatchFailure(
            copy, "unsupported: view-like ops must have one tensor output, "
                  "and the tensor output must be the first result");
@@ -242,7 +242,7 @@ public:
      // non-value tensor and the output being a value tensor. If this is the
      // case then there is no need to look at the users of the result of the
      // op.
-      if (opResult.getType().isa<NonValueTensorType>()) {
+      if (isa<NonValueTensorType>(opResult.getType())) {
        if (operand.getOperandNumber() == 0) {
          validViewLikeOps.insert(op);
          llvm::append_range(workList, opResult.getUses());
@@ -339,7 +339,7 @@ public:
    for (Operation *op : viewLikeOps) {
      rewriter.modifyOpInPlace(op, [&]() {
        if (auto nonValueTensorType =
-                op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
+                dyn_cast<NonValueTensorType>(op->getResult(0).getType())) {
          originalTypes[op->getResult(0)] = nonValueTensorType;
          op->getResult(0).setType(nonValueTensorType.getWithValueSemantics());
        }
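
Several hunks here rely on dyn_cast tolerating a mismatch, and the free-function form keeps that contract: a failed dyn_cast yields a null-testable value. A small sketch of the idiom as used above (the helper name is illustrative; NonValueTensorType and getWithValueSemantics() come from the Torch dialect headers):

    #include "mlir/Support/LLVM.h"
    #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

    // If `result` is a non-value tensor, rewrite its type to the
    // value-semantic variant; otherwise leave it untouched.
    static void maybeGiveValueSemantics(mlir::Value result) {
      using mlir::torch::Torch::NonValueTensorType;
      if (auto resultType = mlir::dyn_cast<NonValueTensorType>(result.getType()))
        result.setType(resultType.getWithValueSemantics());
    }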

View File

@@ -30,7 +30,7 @@ public:
  LogicalResult matchAndRewrite(PrimCallMethodOp op,
                                PatternRewriter &rewriter) const override {
    auto classType = symbolTable.lookup<ClassTypeOp>(
-        op.getReceiver().getType().cast<NnModuleType>().getClassName());
+        cast<NnModuleType>(op.getReceiver().getType()).getClassName());
    assert(classType && "malformed module -- missing ClassTypeOp");
    func::FuncOp func;
    for (auto method : classType.getOps<MethodOp>()) {
@@ -94,7 +94,7 @@ class PrepareForGlobalizeObjectGraphPass
    ConversionTarget target(*context);
    target.addIllegalOp<PrimCallMethodOp>();
    target.addDynamicallyLegalOp<func::ConstantOp>(
-        [](func::ConstantOp op) { return !op.getType().isa<FunctionType>(); });
+        [](func::ConstantOp op) { return !isa<FunctionType>(op.getType()); });
    target.addIllegalOp<func::CallIndirectOp>();
    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

View File

@@ -78,7 +78,7 @@ public:
    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

    // Create IndexPut_Op
-    BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
+    BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
    Type rangeType = tensorType.getWithSizesAndDtype(
        {kUnknownSize}, tensorType.getOptionalDtype());
    Value range = rewriter.create<AtenArangeStartStepOp>(
@@ -130,8 +130,7 @@ public:
    // Create IndexPut_Op
    // Convert indexNum to indexTensor for the selectOp
-    BaseTensorType selectOutTy =
-        selectOp.getType().template cast<BaseTensorType>();
+    BaseTensorType selectOutTy = cast<BaseTensorType>(selectOp.getType());
    SmallVector<int64_t> empty;
    auto dtype = getTypeForTorchType(selectOp.getContext(),
                                     selectOp.getIndex().getType());
@@ -141,7 +140,7 @@ public:
        selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
    // Create indicesVector for IndexPut_Op by TorchNone and indexTensor
-    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
+    BaseTensorType tensorType = cast<BaseTensorType>(op->getResultTypes()[0]);
    SmallVector<Value> indicesVector(dim, noneVal);
    indicesVector.push_back(indexTensor);

View File

@@ -26,8 +26,8 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
                                          Location loc, Value overwriterTensor,
                                          Value overwrittenTensor) {
  Type overwriterTensorType = overwriterTensor.getType();
-  Type overwrittenTensorType = overwrittenTensor.getType()
-                                   .dyn_cast<NonValueTensorType>()
-                                   .getWithValueSemantics();
+  Type overwrittenTensorType =
+      dyn_cast<NonValueTensorType>(overwrittenTensor.getType())
+          .getWithValueSemantics();
  if (overwriterTensorType != overwrittenTensorType) {
    overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
@@ -58,7 +58,7 @@ operatorOpHasValueSemantics(OperatorOp opOp,
                            std::optional<SymbolTable> extraLibrary) {
  if (!extraLibrary.has_value())
    return false;
-  auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+  auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
  std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
                                 LibraryFunctionKind::HasValueSemantics) +
                             Twine(opName))
@@ -96,8 +96,8 @@ public:
        opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
                                                           opOperand.get()));
      } else if (auto listType = dyn_cast<ListType>(operandType)) {
-        if (!(listType.getContainedType().isa<NonValueTensorType>() ||
-              listType.getContainedType().isa<OptionalType>()))
+        if (!(isa<NonValueTensorType>(listType.getContainedType()) ||
+              isa<OptionalType>(listType.getContainedType())))
          continue;

        // Construct a new list whose elements are value tensors copied from
@@ -116,7 +116,7 @@ public:
        // TODO: Handle optional type in list type.
        if (auto optionalType =
-                listType.getContainedType().dyn_cast<OptionalType>()) {
+                dyn_cast<OptionalType>(listType.getContainedType())) {
          if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
                return val.getType().isa<NonValueTensorType, Torch::NoneType>();
              })) {
@@ -129,7 +129,7 @@ public:
        auto newListElements = llvm::to_vector(llvm::map_range(
            listConstruct.getElements(), [&](Value tensor) -> Value {
-              if (tensor.getType().isa<NonValueTensorType>()) {
+              if (isa<NonValueTensorType>(tensor.getType())) {
                return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
                                                            tensor);
              }
@@ -147,7 +147,7 @@ public:
      } else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
        // TODO: A more general way to handle the optional type is to
        // introduce a `copy.to_optional_vtensor` op.
-        if (!optionalType.getContainedType().isa<NonValueTensorType>())
+        if (!isa<NonValueTensorType>(optionalType.getContainedType()))
          continue;

        // Create a new optional value whose input is a value tensor copied
@@ -160,7 +160,7 @@ public:
                                             "derefine");
        }
-        if (!derefine.getOperand().getType().isa<NonValueTensorType>())
+        if (!isa<NonValueTensorType>(derefine.getOperand().getType()))
          continue;
        auto newOperand = rewriter.create<CopyToValueTensorOp>(
            op->getLoc(), derefine.getOperand());
@@ -172,7 +172,7 @@ public:
    // Convert all results.
    rewriter.setInsertionPointAfter(op);
    for (Value result : op->getResults()) {
-      auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
+      auto tensorType = dyn_cast<NonValueTensorType>(result.getType());
      if (!tensorType)
        continue;
      result.setType(tensorType.getWithValueSemantics());
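
One context line above still uses the member form with a type list (`val.getType().isa<NonValueTensorType, Torch::NoneType>()`); it is untouched by this hunk, but the free function accepts the same variadic type list, so a follow-up could migrate it identically. A sketch, assuming the surrounding Torch dialect names:

    // Equivalent free-function form of the remaining member-form check:
    bool ok = isa<NonValueTensorType, Torch::NoneType>(val.getType());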

View File

@@ -84,7 +84,7 @@ class RefinePublicReturnPass
        }
      }

-      if (auto tensorType = newOperand.getType().dyn_cast<BaseTensorType>()) {
+      if (auto tensorType = dyn_cast<BaseTensorType>(newOperand.getType())) {
        newOperands.push_back(
            copyTensorToType(builder, returnOp->getLoc(),
                             tensorType.getWithValueSemantics(), newOperand));

View File

@@ -118,7 +118,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
  assert(call.getNumResults() == 1 &&
         "Multiple results are packed in a tuple in Python!");
  Value result = call.getResult(0);
-  if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
+  if (auto tupleType = dyn_cast<Torch::TupleType>(result.getType())) {
    auto unpack = b.create<PrimTupleUnpackOp>(
        loc, tupleType.getContainedTypes(), result);
    llvm::append_range(unpackedResults, unpack.getResults());
@@ -275,7 +275,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
    //   for i in range(len(operand)):
    //     adjusted_list.append(adjust(operand[i]))
    //   return adjusted_list
-    auto providedType = operand.getType().cast<Torch::ListType>();
+    auto providedType = cast<Torch::ListType>(operand.getType());
    Value adjustedList =
        b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
    // Create a for-like PrimLoopOp.
@@ -312,7 +312,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
  // signature uses `Scalar` (see comments in torch_ods_gen.py for
  // explanation).
  if (isa<Torch::FloatType>(desiredType) &&
-      operand.getType().isa<Torch::IntType>()) {
+      isa<Torch::IntType>(operand.getType())) {
    return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
  }

View File

@@ -30,7 +30,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
  auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
                             Type desiredType) -> Value {
    if (isa<Torch::TupleType>(desiredType) &&
-        operand.getType().isa<Torch::BaseTensorType>()) {
+        isa<Torch::BaseTensorType>(operand.getType())) {
      Type intType = Torch::IntType::get(b.getContext());
      Type sizeListType = Torch::ListType::get(intType);
      Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);

View File

@@ -41,8 +41,8 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
    auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
    if (!desiredListType)
      return operand;
-    if (operand.getType().isa<Torch::BaseTensorType>() &&
-        desiredListType.getContainedType().isa<Torch::IntType>()) {
+    if (isa<Torch::BaseTensorType>(operand.getType()) &&
+        isa<Torch::IntType>(desiredListType.getContainedType())) {
      return b.create<AtenSizeOp>(loc, desiredType, operand);
    }
    return operand;

View File

@@ -259,7 +259,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
  Type originalResultType = result.getType();
  Type updatedType;
  if (auto originalBaseTensorType =
-          originalResultType.template dyn_cast<BaseTensorType>()) {
+          dyn_cast<BaseTensorType>(originalResultType)) {
    // If we didn't get any new information, there is nothing left for us to do.
    updatedType = meetTensorTypes(originalBaseTensorType,
                                  cast<BaseTensorType>(newResultType));
@@ -267,7 +267,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
    return rewriter.notifyMatchFailure(
        calculateOp, "New type information does not refine old type");
  } else if (auto originalResultType =
-                 result.getType().template dyn_cast<Torch::NumberType>()) {
+                 dyn_cast<Torch::NumberType>(result.getType())) {
    if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
      return rewriter.notifyMatchFailure(
          calculateOp,

View File

@@ -35,7 +35,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
  // Calculate the updated type incorporating the new information.
  Type impliedTypeFromDtype;
-  if (result.getType().isa<Torch::NumberType>()) {
+  if (isa<Torch::NumberType>(result.getType())) {
    FailureOr<Type> torchType =
        getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
    if (failed(torchType)) {
@@ -45,7 +45,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
    }
    impliedTypeFromDtype = *torchType;
  } else if (auto originalResultType =
-                 result.getType().dyn_cast<BaseTensorType>()) {
+                 dyn_cast<BaseTensorType>(result.getType())) {
    FailureOr<Type> builtinType =
        getTypeForScalarType(op->getContext(), dtypeScalarType);
    if (failed(builtinType)) {
@@ -168,12 +168,12 @@ public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
                                PatternRewriter &rewriter) const override {
-    auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
+    auto originalResultType = cast<BaseTensorType>(op.getResult().getType());
    if (originalResultType.hasDtype())
      return rewriter.notifyMatchFailure(
          op, "`PrimNumToTensorScalarOp` already has a dtype");
-    if (op.getA().getType().isa<Torch::NumberType>()) {
+    if (isa<Torch::NumberType>(op.getA().getType())) {
      return rewriter.notifyMatchFailure(op,
                                         "`PrimNumToTensorScalarOp`'s input "
                                         "should have concrete Scalar Type.");

View File

@@ -27,7 +27,7 @@ public:
    Location loc = op.getLoc();
    Value self = op.getSelf();
    MLIRContext *context = op.getContext();
-    auto tensorType = self.getType().cast<BaseTensorType>();
+    auto tensorType = cast<BaseTensorType>(self.getType());
    if (!tensorType.hasSizes())
      return rewriter.notifyMatchFailure(op, "unranked tensor");
    int64_t rank = tensorType.getSizes().size();
@@ -96,7 +96,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
      sizes.push_back(kUnknownSize);
  }

-  auto originalResultType = result.getType().cast<BaseTensorType>();
+  auto originalResultType = cast<BaseTensorType>(result.getType());
  auto impliedTypesFromShape =
      cast<BaseTensorType>(originalResultType)
          .getWithSizesAndDtype(ArrayRef(sizes),

View File

@@ -44,9 +44,9 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
}

torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
-  if (type.isa<Float32Type>())
+  if (isa<Float32Type>(type))
    return torch_upstream::ScalarType::Float;
-  if (type.isa<Float64Type>())
+  if (isa<Float64Type>(type))
    return torch_upstream::ScalarType::Double;
  if (type.isSignedInteger(64))
    return torch_upstream::ScalarType::Long;
@@ -64,11 +64,11 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
    return torch_upstream::ScalarType::Byte;
  if (type.isSignedInteger(8))
    return torch_upstream::ScalarType::Char;
-  if (type.isa<QUInt8Type>())
+  if (isa<QUInt8Type>(type))
    return torch_upstream::ScalarType::QUInt8;
-  if (type.isa<QInt8Type>())
+  if (isa<QInt8Type>(type))
    return torch_upstream::ScalarType::QInt8;
-  if (type.isa<QInt32Type>())
+  if (isa<QInt32Type>(type))
    return torch_upstream::ScalarType::QInt32;
  if (isa<ComplexType>(type)) {
    mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
@@ -185,7 +185,7 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
// Helper to convert a tensor to a specific scalar type.
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                  Value input, Type dtype) {
-  BaseTensorType origType = input.getType().cast<BaseTensorType>();
+  BaseTensorType origType = cast<BaseTensorType>(input.getType());
  Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
  // `convertIntVal` contains the corresponding integer for the dtype which is
  // used by the aten.to.dtype op.
@@ -202,7 +202,7 @@ bool Torch::isBuiltInType(Type type) {
}

std::optional<unsigned> Torch::getTensorRank(Value tensor) {
-  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+  BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
  if (!tensorType.hasSizes())
    return std::nullopt;
  return tensorType.getSizes().size();
@@ -279,7 +279,7 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
// Return the squeezed tensor or failure.
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                      Location loc, int64_t dim, Value input) {
-  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+  BaseTensorType inputType = cast<BaseTensorType>(input.getType());
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(loc, "input tensor must have size");
  }
@@ -314,7 +314,7 @@ FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
// Return the unsqueezed tensor or failure.
FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
                                        Operation *op, Value input, Value dim) {
-  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+  BaseTensorType inputType = cast<BaseTensorType>(input.getType());
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(op, "input tensor must have size");
  }
@@ -348,9 +348,9 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
                                  SmallVector<int64_t> &resultShape,
                                  SmallVector<Value> &resultShapeValue) {
  SmallVector<int64_t> shapeA{
-      inputA.getType().cast<BaseTensorType>().getSizes()};
+      cast<BaseTensorType>(inputA.getType()).getSizes()};
  SmallVector<int64_t> shapeB{
-      inputB.getType().cast<BaseTensorType>().getSizes()};
+      cast<BaseTensorType>(inputB.getType()).getSizes()};
  unsigned rankA = shapeA.size();
  unsigned rankB = shapeB.size();
  unsigned minRank = rankA > rankB ? rankB : rankA;
@@ -504,9 +504,8 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
                               BaseTensorType inputType, Value scalar) {
  assert(inputType.hasDtype() && "input must have dtype");
  SmallVector<int64_t> sizes;
-  BaseTensorType rank0TensorTy =
-      inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
-          .cast<BaseTensorType>();
+  BaseTensorType rank0TensorTy = cast<BaseTensorType>(
+      inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()));
  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
      ValueRange{});
@@ -531,9 +530,9 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
    return rewriter.getF32Type();
  if (inputType.isBF16())
    return rewriter.getF32Type();
-  if (inputType.isa<Float32Type>())
+  if (isa<Float32Type>(inputType))
    return rewriter.getF32Type();
-  if (inputType.isa<Float64Type>())
+  if (isa<Float64Type>(inputType))
    return rewriter.getF64Type();
  if (inputType.isFloat8E5M2())
    return rewriter.getF32Type();
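
Note that only the RTTI-style queries change in this file: member predicates such as isSignedInteger(), isBF16(), or isFloat8E5M2() are ordinary methods on mlir::Type, not the deprecated cast/isa template members, so they stay as-is. A minimal sketch of the distinction (the function name is illustrative):

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Support/LLVM.h"

    bool isDoubleOrSignedI64(mlir::Type type) {
      return mlir::isa<mlir::Float64Type>(type) // migrated: free-function isa
             || type.isSignedInteger(64);       // unaffected: plain predicate
    }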

View File

@@ -34,9 +34,9 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
//===----------------------------------------------------------------------===//

LogicalResult ToBuiltinTensorOp::verify() {
-  auto resultType = getResult().getType().cast<TensorType>();
+  auto resultType = cast<TensorType>(getResult().getType());
  auto operandType =
-      getOperand().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
+      cast<Torch::ValueTensorType>(getOperand().getType()).toBuiltinTensor();
  if (!haveSameSizeAndElementType(resultType, operandType)) {
    return emitError()
           << "operand and result must have the same size and dtype";
@@ -49,7 +49,7 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto resultType =
-      operands[0].getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
+      cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
  if (!resultType)
    return failure();
  inferredReturnTypes.push_back(resultType);
@@ -62,8 +62,8 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
LogicalResult FromBuiltinTensorOp::verify() {
  auto resultType =
-      getResult().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
-  auto operandType = getOperand().getType().cast<TensorType>();
+      cast<Torch::ValueTensorType>(getResult().getType()).toBuiltinTensor();
+  auto operandType = cast<TensorType>(getOperand().getType());
  if (!haveSameSizeAndElementType(resultType, operandType)) {
    return emitError()
           << "operand and result must have the same size and dtype";

View File

@@ -36,7 +36,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
                                          ValueRange inputs,
                                          Location loc) -> Value {
    assert(inputs.size() == 1);
-    if (!inputs[0].getType().isa<Torch::BaseTensorType>())
+    if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
      return {};
    return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
  });
@@ -44,7 +44,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
                                 Torch::ValueTensorType type,
                                 ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<TensorType>());
+    assert(isa<TensorType>(inputs[0].getType()));
    return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);
@@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
        if (!(type.getWidth() == 1 && type.isSignless()))
          return std::nullopt;
        assert(inputs.size() == 1);
-        assert(inputs[0].getType().isa<Torch::BoolType>());
+        assert(isa<Torch::BoolType>(inputs[0].getType()));
        return builder.create<ToI1Op>(loc, inputs[0]).getResult();
      });
  auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<IntegerType>());
+    assert(isa<IntegerType>(inputs[0].getType()));
    return builder.create<FromI1Op>(loc, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);
@@ -99,7 +99,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
  auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<IntegerType>());
+    assert(isa<IntegerType>(inputs[0].getType()));
    return builder.create<FromI64Op>(loc, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);
@@ -116,13 +116,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
      [](OpBuilder &builder, Float64Type type, ValueRange inputs,
         Location loc) -> std::optional<Value> {
        assert(inputs.size() == 1);
-        assert(inputs[0].getType().isa<Torch::FloatType>());
+        assert(isa<Torch::FloatType>(inputs[0].getType()));
        return builder.create<ToF64Op>(loc, inputs[0]).getResult();
      });
  auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<Float64Type>());
+    assert(isa<Float64Type>(inputs[0].getType()));
    return builder.create<FromF64Op>(loc, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);
@@ -153,7 +153,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
  auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<IntegerType>());
+    assert(isa<IntegerType>(inputs[0].getType()));
    return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);

View File

@@ -42,7 +42,7 @@ public:
    // get inputs: lhs, rhsQuant, scales, zps
    Value lhs = adaptor.getOperands()[0];
-    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    auto lhsType = cast<RankedTensorType>(lhs.getType());
    if (!lhsType) {
      return failure();
    }
@@ -50,7 +50,7 @@ public:
    int lhsReductDimSize = lhsShape.back();

    Value rhsQuant = adaptor.getOperands()[1];
-    auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
+    auto rhsType = cast<RankedTensorType>(rhsQuant.getType());
    if (!rhsType) {
      return failure();
    }

View File

@@ -59,7 +59,7 @@ public:
    if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
      return failure();

-    auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
+    auto rhsType = dyn_cast<ValueTensorType>(rhs.getType());
    if (!rhsType)
      return failure();
@@ -88,7 +88,7 @@ public:
    ValueTensorType newRhsType = ValueTensorType::get(
        rewriter.getContext(), tensorShape, unpackedElementType);

-    auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
+    auto elements = dyn_cast<DenseIntElementsAttr>(constOp.getValueAttr());
    if (!elements)
      return failure();

View File

@@ -234,7 +234,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
  if (!globalOp.getValue().has_value())
    return globalOp.emitError("global op must have a value");

-  RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
+  RankedTensorType tensorType = cast<RankedTensorType>(globalOp.getType());
  MemRefType memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
@@ -252,7 +252,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
static LogicalResult
bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
                              OpBuilder &b, SmallVector<Operation *> &toErase) {
-  RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
+  RankedTensorType tensorType = cast<RankedTensorType>(globalLoadOp.getType());
  MemRefType memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
@@ -271,7 +271,7 @@ bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
                               OpBuilder &b,
                               SmallVector<Operation *> &toErase) {
  RankedTensorType tensorType =
-      globalStoreOp.getValue().getType().cast<RankedTensorType>();
+      cast<RankedTensorType>(globalStoreOp.getValue().getType());
  MemRefType memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
@@ -300,7 +300,7 @@ class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
    SmallVector<Operation *> toErase;
    auto walkResult = module.walk([&](ml_program::GlobalOp op) {
-      if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
+      if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
        if (!type.hasStaticShape()) {
          // If the ml_program.global has dynamically shaped tensor.
          op.emitError(
@@ -387,8 +387,8 @@ mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
                              Value to) {
-  auto memrefTypeFrom = from.getType().cast<MemRefType>();
-  auto memrefTypeTo = to.getType().cast<MemRefType>();
+  auto memrefTypeFrom = cast<MemRefType>(from.getType());
+  auto memrefTypeTo = cast<MemRefType>(to.getType());
  (void)memrefTypeFrom;
  assert(memrefTypeFrom && memrefTypeTo &&
         memrefTypeFrom.getRank() == memrefTypeTo.getRank());
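
A side benefit visible in a few hunks above (RecomposeComplexOps, SimplifyAbstractInterpCalculationsUtils): inside templated code the member form needed the `template` disambiguator (`x.getType().template cast<T>()`), which the free-function form makes unnecessary. A hedged sketch with illustrative names:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Support/LLVM.h"

    // Before: `template` keyword required because `op`'s type is dependent.
    template <typename OpT>
    mlir::RankedTensorType resultTypeOld(OpT op) {
      return op.getType().template cast<mlir::RankedTensorType>(); // deprecated
    }

    // After: the free function needs no disambiguation.
    template <typename OpT>
    mlir::RankedTensorType resultTypeNew(OpT op) {
      return mlir::cast<mlir::RankedTensorType>(op.getType());
    }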