[MLIR][Torch] Resolve styling issues related to aten zeros/ones op

https://github.com/llvm/torch-mlir/pull/464#discussion_r765065092

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/470/head
Vivek Khandelwal 2021-12-09 17:31:17 +05:30
parent f34eb66124
commit 0a0a1b4476
1 changed file with 15 additions and 13 deletions

@@ -3329,41 +3329,43 @@ struct ConvertAtenOnesZerosOp : ConversionPattern {
       return failure();
     Location loc = op->getLoc();
-    SmallVector<Value, 3> opArguments;
+    Value size, layout, pin_memory;
     int64_t elementValue;
     if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      opArguments.insert(opArguments.end(),
-                         {onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
+      size = onesOp.size();
+      layout = onesOp.layout();
+      pin_memory = onesOp.pin_memory();
       elementValue = 1;
     } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
-                                             zerosOp.pin_memory()});
+      size = zerosOp.size();
+      layout = zerosOp.layout();
+      pin_memory = zerosOp.pin_memory();
       elementValue = 0;
     }
     // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!opArguments[1].getType().isa<Torch::NoneType>())
+    if (!layout.getType().isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(op,
                                          "only default layout is supported");
     bool pinMemory = false;
-    if (!opArguments[2].getType().isa<Torch::NoneType>() &&
-        !matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
+    if (!pin_memory.getType().isa<Torch::NoneType>() &&
+        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
       return rewriter.notifyMatchFailure(
           op, "pin_memory must be constant bool or None");
     }
     if (pinMemory)
       return rewriter.notifyMatchFailure(op, "memory pinning not supported");
-    SmallVector<Value> size, sizeIndex;
-    if (!getListConstructElements(opArguments[0], size)) {
+    SmallVector<Value> sizes, sizeIndex;
+    if (!getListConstructElements(size, sizes)) {
       return rewriter.notifyMatchFailure(
           op, "size must be created by ListConstruct");
     }
-    size = getTypeConvertedValues(rewriter, loc, getTypeConverter(), size);
-    for (size_t i = 0; i < size.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
+    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
+    for (size_t i = 0; i < sizes.size(); i++)
+      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
     RankedTensorType newResultType =
         getTypeConverter()
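
In plain terms, the styling fix swaps a positional SmallVector of operands for one named Value per operand, so each later check says which operand it inspects. Below is a minimal, dependency-free C++ sketch of that same readability trade-off; the OnesOp type and its accessors are hypothetical stand-ins for the torch-mlir ops, and std::string stands in for mlir::Value:

// Sketch of the styling change, using hypothetical stand-in types.
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an op carrying size/layout/pin_memory operands.
struct OnesOp {
  std::string size() const { return "[2, 3]"; }
  std::string layout() const { return "none"; }
  std::string pinMemory() const { return "none"; }
};

int main() {
  OnesOp onesOp;

  // Old style: operands packed into a vector and read back by index.
  // A reader must remember that [1] means "layout" and [2] "pin_memory".
  std::vector<std::string> opArguments = {onesOp.size(), onesOp.layout(),
                                          onesOp.pinMemory()};
  if (opArguments[1] != "none")
    std::cout << "only default layout is supported\n";

  // New style: one named local per operand; each check is self-describing.
  std::string size = onesOp.size();
  std::string layout = onesOp.layout();
  std::string pinMemory = onesOp.pinMemory();
  if (layout != "none")
    std::cout << "only default layout is supported\n";

  std::cout << "size = " << size << ", pin_memory = " << pinMemory << "\n";
  return 0;
}

The vector bought nothing here since the three operands are never iterated, while the named locals make the layout and pin_memory checks in the diff above read directly, which is the point of the linked review comment.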