Adds result types to a prelu decomp (#3098)

This adds explicit result types instead of relying on shape/dtype
computations.

Solves a regression issue with IREE: #3092
pull/3100/head
zjgarvey 2024-04-02 13:41:56 -05:00 committed by GitHub
parent 6cbb2f7ae0
commit 40e762ca42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 5 deletions

View File

@ -2382,15 +2382,15 @@ public:
Location loc = op.getLoc();
Value input = op.getSelf();
Value weight = op.getWeight();
auto resType = op.getType().cast<BaseTensorType>();
auto baseType =
ValueTensorType::getWithLeastStaticInformation(op.getContext());
auto resType = op.getType().cast<ValueTensorType>();
auto boolTensorType = rewriter.getType<ValueTensorType>(
resType.getOptionalSizes(), rewriter.getI1Type());
Value zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value inputMulWeight =
rewriter.create<AtenMulTensorOp>(loc, baseType, input, weight);
rewriter.create<AtenMulTensorOp>(loc, resType, input, weight);
Value lessThanZero =
rewriter.create<AtenLtScalarOp>(loc, baseType, input, zero);
rewriter.create<AtenLtScalarOp>(loc, boolTensorType, input, zero);
Value preluOutput = rewriter.create<AtenWhereSelfOp>(
loc, resType, lessThanZero, inputMulWeight, input);