Skip to content

Commit 2dadf6c

Browse files
committed
Go: introduce special GoOutputWriter and move generation of error check to it
This will allow to override this method and do not generate check in KST when performing checks for returned errors using `asserts[i].exception` key
1 parent 128e179 commit 2dadf6c

File tree

3 files changed

+37
-23
lines changed

3 files changed

+37
-23
lines changed

jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ class TranslatorSpec extends AnyFunSpec {
950950
eo = Some(Expressions.parse(src))
951951
}
952952

953-
val goOutput = new StringLanguageOutputWriter(" ")
953+
val goOutput = new GoOutputWriter(" ")
954954

955955
val langs = ListMap[LanguageCompilerStatic, AbstractTranslator with TypeDetector](
956956
CppCompiler -> new CppTranslator(tp, new CppImportList(), new CppImportList(), RuntimeConfig()),

shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import io.kaitai.struct.datatype._
55
import io.kaitai.struct.exprlang.Ast
66
import io.kaitai.struct.format._
77
import io.kaitai.struct.languages.components._
8-
import io.kaitai.struct.translators.{GoTranslator, ResultString, TranslatorResult}
8+
import io.kaitai.struct.translators.{GoOutputWriter, GoTranslator, ResultString, TranslatorResult}
99
import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils}
1010

1111
class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
@@ -19,6 +19,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
1919
with GoReads {
2020
import GoCompiler._
2121

22+
override val out = new GoOutputWriter(indent)
2223
override val translator = new GoTranslator(out, typeProvider, importList)
2324

2425
override def innerClasses = false
@@ -270,21 +271,21 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
270271

271272
override def pushPos(io: String): Unit = {
272273
out.puts(s"_pos, err := $io.Pos()")
273-
translator.outAddErrCheck()
274+
out.putsErrCheck(translator.returnRes)
274275
}
275276

276277
override def seek(io: String, pos: Ast.expr): Unit = {
277278
importList.add("io")
278279

279280
out.puts(s"_, err = $io.Seek(int64(${expression(pos)}), io.SeekStart)")
280-
translator.outAddErrCheck()
281+
out.putsErrCheck(translator.returnRes)
281282
}
282283

283284
override def popPos(io: String): Unit = {
284285
importList.add("io")
285286

286287
out.puts(s"_, err = $io.Seek(_pos, io.SeekStart)")
287-
translator.outAddErrCheck()
288+
out.putsErrCheck(translator.returnRes)
288289
}
289290

290291
override def alignToByte(io: String): Unit =
@@ -306,7 +307,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
306307

307308
val eofVar = translator.allocateLocalVar()
308309
out.puts(s"${translator.localVarName(eofVar)}, err := this._io.EOF()")
309-
translator.outAddErrCheck()
310+
out.putsErrCheck(translator.returnRes)
310311
out.puts(s"if ${translator.localVarName(eofVar)} {")
311312
out.inc
312313
out.puts("break")

shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,30 @@ sealed trait TranslatorResult
1515
case class ResultString(s: String) extends TranslatorResult
1616
case class ResultLocalVar(n: Int) extends TranslatorResult
1717

18-
class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, importList: ImportList)
18+
class GoOutputWriter(indentStr: String) extends StringLanguageOutputWriter(indentStr) {
19+
/**
20+
* Puts to the output code to check variable `err` for error value and emit
21+
* a premature return with value of error.
22+
*
23+
* @param result If not none this value will be returned as the first value of
24+
* the returned tuple, otherwise only `err` is returned
25+
*/
26+
def putsErrCheck(result: Option[String]): Unit = {
27+
puts("if err != nil {")
28+
inc
29+
30+
val noValueAndErr = result match {
31+
case None => "err"
32+
case Some(r) => s"$r, err"
33+
}
34+
35+
puts(s"return $noValueAndErr")
36+
dec
37+
puts("}")
38+
}
39+
}
40+
41+
class GoTranslator(out: GoOutputWriter, provider: TypeProvider, importList: ImportList)
1942
extends TypeDetector(provider)
2043
with AbstractTranslator
2144
with CommonLiterals
@@ -26,6 +49,10 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
2649

2750
import io.kaitai.struct.languages.GoCompiler._
2851

52+
/**
53+
* Dummy return value that should be returned in case of error just because
54+
* we cannot return nothing.
55+
*/
2956
var returnRes: Option[String] = None
3057

3158
override def translate(v: Ast.expr, extPrec: Int): String = resToStr(translateExpr(v, extPrec))
@@ -470,7 +497,7 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
470497
val addParams = t.args.map((a) => translate(a)).mkString(", ")
471498
out.puts(s"${localVarName(v)} := New${GoCompiler.types2class(t.classSpec.get.name)}($addParams)")
472499
out.puts(s"err = ${localVarName(v)}.Read($io, $parent, $root)")
473-
outAddErrCheck()
500+
out.putsErrCheck(returnRes)
474501
ResultLocalVar(v)
475502
}
476503

@@ -502,7 +529,7 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
502529
def outVarCheckRes(expr: String): ResultLocalVar = {
503530
val v1 = allocateLocalVar()
504531
out.puts(s"${localVarName(v1)}, err := $expr")
505-
outAddErrCheck()
532+
out.putsErrCheck(returnRes)
506533
ResultLocalVar(v1)
507534
}
508535

@@ -521,20 +548,6 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
521548

522549
def localVarName(n: Int) = s"tmp$n"
523550

524-
def outAddErrCheck(): Unit = {
525-
out.puts("if err != nil {")
526-
out.inc
527-
528-
val noValueAndErr = returnRes match {
529-
case None => "err"
530-
case Some(r) => s"$r, err"
531-
}
532-
533-
out.puts(s"return $noValueAndErr")
534-
out.dec
535-
out.puts("}")
536-
}
537-
538551
override def byteSizeOfValue(attrName: String, valType: DataType): TranslatorResult =
539552
trIntLiteral(CommonSizeOf.bitToByteSize(CommonSizeOf.getBitsSizeOfType(attrName, valType)))
540553
}

0 commit comments

Comments
 (0)